MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go WaitGroup的并发安全保障

2022-03-097.9k 阅读

Go语言并发编程基础

在深入探讨 WaitGroup 的并发安全保障之前,我们先来回顾一下Go语言并发编程的一些基础知识。Go语言在设计之初就将并发编程作为其核心特性之一,通过 goroutinechannel 这两个强大的工具,开发者可以轻松地编写高效的并发程序。

goroutine

goroutine 是Go语言中实现并发的轻量级线程。与传统线程相比,goroutine 的创建和销毁开销非常小,可以在一台机器上轻松创建数以万计的 goroutine。创建一个 goroutine 非常简单,只需要在函数调用前加上 go 关键字即可。

package main

import (
    "fmt"
    "time"
)

func worker() {
    fmt.Println("Worker started")
    time.Sleep(2 * time.Second)
    fmt.Println("Worker finished")
}

func main() {
    go worker()
    time.Sleep(3 * time.Second)
    fmt.Println("Main finished")
}

在上述代码中,go worker() 启动了一个新的 goroutine 来执行 worker 函数。主 goroutine 会继续执行后续代码,而不会等待 worker 函数执行完毕。通过 time.Sleep 函数,我们可以让主 goroutine 等待一段时间,以确保 worker goroutine 有足够的时间执行。

channel

channel 是Go语言中用于 goroutine 之间通信和同步的机制。它可以被看作是一个类型化的管道,数据可以从一端发送,从另一端接收。channel 的使用可以有效地避免共享数据带来的竞态条件问题。

package main

import (
    "fmt"
)

func sender(ch chan int) {
    for i := 0; i < 5; i++ {
        ch <- i
    }
    close(ch)
}

func receiver(ch chan int) {
    for num := range ch {
        fmt.Println("Received:", num)
    }
}

func main() {
    ch := make(chan int)
    go sender(ch)
    receiver(ch)
    fmt.Println("Main finished")
}

在这个例子中,sender goroutine 通过 ch <- i 将数据发送到 channel 中,receiver goroutine 使用 for... range 循环从 channel 中接收数据。当 sender goroutine 关闭 channel 时,receiver goroutinefor... range 循环会自动结束。

并发安全问题

在并发编程中,当多个 goroutine 同时访问和修改共享资源时,就可能会出现并发安全问题,其中最典型的就是竞态条件(race condition)。

竞态条件示例

package main

import (
    "fmt"
    "sync"
)

var counter int

func increment(wg *sync.WaitGroup) {
    defer wg.Done()
    for i := 0; i < 1000; i++ {
        counter++
    }
}

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go increment(&wg)
    }
    wg.Wait()
    fmt.Println("Final counter value:", counter)
}

在上述代码中,我们启动了10个 goroutine 来对 counter 变量进行递增操作。每个 goroutine 会对 counter 进行1000次递增。理想情况下,最终 counter 的值应该是10000。但是,由于多个 goroutine 同时访问和修改 counter,会导致竞态条件,最终 counter 的值往往小于10000。

问题根源

counter++ 看似是一个原子操作,但实际上它包含了读取 counter 的值、增加1、再写回 counter 这三个步骤。在并发环境下,当一个 goroutine 读取了 counter 的值,还未完成增加和写回操作时,另一个 goroutine 也读取了 counter 的值,这样就会导致两次增加操作只生效了一次,从而产生数据不一致的问题。

WaitGroup简介

WaitGroup 是Go语言标准库 sync 包中的一个类型,它用于协调多个 goroutine 的同步。WaitGroup 可以等待一组 goroutine 全部完成任务后再继续执行后续代码。

WaitGroup的基本方法

  • Add(delta int):将 WaitGroup 的计数器增加 delta。如果 delta 为负数,会导致计数器减少。通常在启动 goroutine 前调用 Add 方法,传入需要等待的 goroutine 的数量。
  • Done():将 WaitGroup 的计数器减1。通常在 goroutine 完成任务后调用 Done 方法。它等同于 Add(-1)
  • Wait():阻塞当前 goroutine,直到 WaitGroup 的计数器变为0。

WaitGroup的内部实现原理

WaitGroup 的实现基于Go语言的运行时调度器和信号量机制。

数据结构

WaitGroup 在Go语言的标准库中定义如下:

// $GOROOT/src/sync/waitgroup.go
type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

其中,state1 数组用于存储计数器的值和等待队列的信息。前两个 uint32 用于存储计数器和等待队列的长度,第三个 uint32 用于存储等待队列的指针。

实现细节

  1. Add 方法Add 方法通过原子操作(atomic.AddUint64)来增加计数器的值。如果计数器的值变为负数,会导致程序崩溃,因为这表示 Add 调用的次数超过了预期。
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if v > 0 || w == 0 {
        return
    }
    // 计数器变为0,唤醒所有等待的goroutine
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}
  1. Done 方法Done 方法实际上是 Add(-1) 的别名,它同样通过原子操作减少计数器的值。
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}
  1. Wait 方法Wait 方法会检查计数器是否为0。如果不为0,会将当前 goroutine 放入等待队列,并通过信号量机制(runtime_Semacquire)阻塞当前 goroutine。当计数器变为0时,会唤醒等待队列中的所有 goroutine
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            return
        }
        // 将当前goroutine放入等待队列并阻塞
        runtime_Semacquire(semap)
    }
}

WaitGroup保障并发安全的方式

WaitGroup 本身并不直接处理共享资源的并发访问,但它通过协调 goroutine 的同步,间接地保障了并发安全。

避免竞态条件

回到之前的 counter 示例,我们可以通过 WaitGroup 确保所有 goroutine 完成 counter 的递增操作后再输出结果。虽然这并不能解决 counter 本身的竞态条件问题,但可以让我们在所有操作完成后得到一个确定的结果。

package main

import (
    "fmt"
    "sync"
)

var counter int

func increment(wg *sync.WaitGroup) {
    defer wg.Done()
    for i := 0; i < 1000; i++ {
        counter++
    }
}

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go increment(&wg)
    }
    wg.Wait()
    fmt.Println("Final counter value:", counter)
}

在这个例子中,wg.Wait() 会阻塞主 goroutine,直到所有10个 goroutine 完成 counter 的递增操作。这样,我们在输出 counter 的值时,所有的操作都已经完成,虽然 counter 的递增过程中存在竞态条件,但最终的输出结果是所有操作完成后的结果。

与其他同步机制结合

为了真正解决 counter 的竞态条件问题,我们可以结合 WaitGroupsync.Mutexsync.Mutex 用于保护共享资源 counter,确保同一时间只有一个 goroutine 可以访问和修改 counter

package main

import (
    "fmt"
    "sync"
)

var counter int
var mu sync.Mutex

func increment(wg *sync.WaitGroup) {
    defer wg.Done()
    for i := 0; i < 1000; i++ {
        mu.Lock()
        counter++
        mu.Unlock()
    }
}

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go increment(&wg)
    }
    wg.Wait()
    fmt.Println("Final counter value:", counter)
}

在这个改进后的代码中,mu.Lock()mu.Unlock() 确保了 counter++ 操作的原子性,避免了竞态条件。WaitGroup 则负责等待所有 goroutine 完成任务,从而保证了整个程序的并发安全性。

WaitGroup在复杂场景中的应用

在实际的开发中,WaitGroup 常常应用于复杂的并发场景,比如分布式系统中的任务调度、Web服务器的请求处理等。

分布式任务调度示例

假设我们有一个分布式系统,需要在多个节点上执行相同的任务,并等待所有任务完成后进行汇总。

package main

import (
    "fmt"
    "sync"
    "time"
)

func task(nodeID int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Node %d started task\n", nodeID)
    time.Sleep(time.Duration(nodeID) * time.Second)
    fmt.Printf("Node %d finished task\n", nodeID)
}

func main() {
    var wg sync.WaitGroup
    nodes := []int{1, 2, 3, 4, 5}
    for _, node := range nodes {
        wg.Add(1)
        go task(node, &wg)
    }
    wg.Wait()
    fmt.Println("All tasks finished, starting result aggregation")
}

在这个示例中,每个节点模拟一个分布式任务,WaitGroup 用于等待所有节点的任务完成,然后再进行结果汇总(这里只是简单输出提示信息)。

Web服务器请求处理示例

在Web服务器中,我们可能需要并发处理多个请求,并在所有请求处理完成后进行一些清理工作。

package main

import (
    "fmt"
    "net/http"
    "sync"
)

func handleRequest(w http.ResponseWriter, r *http.Request, wg *sync.WaitGroup) {
    defer wg.Done()
    // 处理请求逻辑
    fmt.Fprintf(w, "Request handled")
}

func main() {
    var wg sync.WaitGroup
    http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        wg.Add(1)
        go handleRequest(w, r, &wg)
    })
    go func() {
        // 模拟一段时间后关闭服务器
        time.Sleep(5 * time.Second)
        fmt.Println("Shutting down server, waiting for requests to finish")
        wg.Wait()
        fmt.Println("All requests finished, server shutdown complete")
    }()
    http.ListenAndServe(":8080", nil)
}

在这个示例中,每个HTTP请求都会启动一个新的 goroutine 来处理,WaitGroup 用于等待所有请求处理完成后再关闭服务器,确保不会丢失任何请求的处理结果。

WaitGroup的注意事项

在使用 WaitGroup 时,有一些注意事项需要我们关注,以确保程序的正确性和稳定性。

避免重复使用

WaitGroup 设计用于一次性等待一组 goroutine。如果在 Wait 调用之后再次使用 Add 方法增加计数器,可能会导致未定义行为。

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine started")
        time.Sleep(2 * time.Second)
        fmt.Println("Goroutine finished")
    }()
    wg.Wait()
    // 不应该在这里再次调用Add
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Another goroutine started")
        time.Sleep(2 * time.Second)
        fmt.Println("Another goroutine finished")
    }()
    wg.Wait()
}

在上述代码中,第二次调用 Add 是不推荐的做法。如果确实需要多次等待不同组的 goroutine,可以考虑使用多个 WaitGroup

注意计数器的增减平衡

在使用 AddDone 方法时,一定要确保计数器的增加和减少操作是平衡的。如果 Add 的次数多于 DoneWait 方法可能会永远阻塞;反之,如果 Done 的次数多于 Add,会导致程序崩溃。

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(2)
    go func() {
        defer wg.Done()
        fmt.Println("First goroutine")
    }()
    // 这里少了一次wg.Done(),导致Wait永远阻塞
    wg.Wait()
    fmt.Println("Main finished")
}

在这个例子中,由于少了一次 wg.Done()wg.Wait() 会永远阻塞,导致程序无法正常结束。

避免在 goroutine 外部调用 Done

Done 方法应该在 goroutine 内部调用,以确保计数器的减少与 goroutine 的完成相对应。在 goroutine 外部调用 Done 可能会导致计数器减少的时机不正确。

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    // 错误:在goroutine外部调用Done
    wg.Done()
    go func() {
        fmt.Println("Goroutine started")
        time.Sleep(2 * time.Second)
        fmt.Println("Goroutine finished")
    }()
    wg.Wait()
    fmt.Println("Main finished")
}

在上述代码中,在启动 goroutine 之前调用 wg.Done() 是错误的,这可能会导致 Wait 方法过早返回,而 goroutine 还未完成。

通过合理使用 WaitGroup,并注意上述事项,我们可以有效地保障Go语言程序在并发环境下的安全性和稳定性,充分发挥Go语言并发编程的优势。无论是简单的并发任务,还是复杂的分布式系统,WaitGroup 都是一个不可或缺的工具。