MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go WaitGroup的实现原理揭秘

2021-09-014.2k 阅读

Go WaitGroup 的基本使用

在 Go 语言中,WaitGroup 是一个非常实用的同步原语,用于等待一组 goroutine 完成执行。下面通过一个简单的示例来展示其基本用法:

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模拟一些工作
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 3
    for i := 1; i <= numWorkers; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
    fmt.Println("All workers are done")
}

在上述代码中,我们创建了一个 WaitGroup 实例 wg。在启动每个 goroutine 之前,通过 wg.Add(1) 来增加等待组的计数。在 worker 函数中,使用 defer wg.Done() 来标记该 goroutine 完成工作,这实际上是将等待组的计数减 1。最后,在 main 函数中调用 wg.Wait(),它会阻塞当前 goroutine,直到等待组的计数变为 0,即所有的 goroutine 都调用了 Done

WaitGroup 的数据结构

要深入理解 WaitGroup 的实现原理,我们需要先了解其底层的数据结构。在 Go 语言的源码中,WaitGroup 的定义如下:

// src/sync/waitgroup.go
type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

这里的 noCopy 是一个标记结构体,用于防止 WaitGroup 被复制,因为复制 WaitGroup 可能会导致未定义的行为。而 state1 这个数组则是 WaitGroup 实现的关键。它实际上表示了两个值:计数器和等待者的数量,以及一个信号量。

状态表示

state1 数组通过位运算来表示不同的状态信息。假设 statestate1 数组转换为一个 64 位整数(在 64 位系统上),低 32 位表示等待组的计数器(counter),高 32 位表示等待者的数量(waiters)。而信号量则是通过 runtime_Semacquireruntime_Semrelease 等函数来操作。

Add 方法的实现

Add 方法用于增加等待组的计数器。其实现代码如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
        return
    }
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

Add 方法中,首先通过 wg.state() 获取状态指针 statep 和信号量指针 semap。然后使用 atomic.AddUint64 原子地增加计数器。如果增加后的计数器为负数,会触发 panic。如果在有等待者的情况下,同时调用 AddWait 且增加的数量与当前等待者数量相等,也会触发 panic,这是为了防止错误的使用。如果增加后计数器大于 0 或者没有等待者,直接返回。否则,将状态设置为 0,并释放所有等待者的信号量。

Done 方法的实现

Done 方法实际上是 Add(-1) 的快捷方式,其实现如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

通过调用 Add(-1),将等待组的计数器减 1。

Wait 方法的实现

Wait 方法用于阻塞当前 goroutine,直到等待组的计数器变为 0。其实现代码如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            return
        }
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

Wait 方法中,首先获取状态指针和信号量指针。然后进入一个循环,每次循环加载当前状态。如果计数器为 0,说明所有工作已经完成,直接返回。否则,尝试通过 atomic.CompareAndSwapUint64 原子地增加等待者的数量。如果成功,调用 runtime_Semacquire 阻塞当前 goroutine,等待信号量释放。当信号量被释放后,再次检查状态,如果状态不为 0,说明 WaitGroup 在之前的 Wait 返回之前被重用了,触发 panic

应用场景

  1. 多任务并行处理:在许多实际应用中,我们可能需要并行执行多个任务,例如从多个数据源获取数据,然后在所有数据获取完成后进行合并处理。通过 WaitGroup 可以方便地实现这种场景。
package main

import (
    "fmt"
    "sync"
)

func fetchData(source int, result *[]int, wg *sync.WaitGroup) {
    defer wg.Done()
    // 模拟数据获取
    data := []int{source * 10, source * 10 + 1}
    *result = append(*result, data...)
}

func main() {
    var wg sync.WaitGroup
    var allData []int
    numSources := 3
    for i := 1; i <= numSources; i++ {
        wg.Add(1)
        go fetchData(i, &allData, &wg)
    }
    wg.Wait()
    fmt.Println("All data fetched:", allData)
}
  1. 服务启动与关闭:在开发服务器应用时,可能需要启动多个 goroutine 来处理不同的服务功能,例如 HTTP 服务、数据库连接池管理等。在关闭服务器时,需要确保所有这些 goroutine 都能安全地停止。WaitGroup 可以用于等待所有服务 goroutine 完成清理工作。
package main

import (
    "fmt"
    "net/http"
    "sync"
    "time"
)

func startHTTPServer(wg *sync.WaitGroup) {
    defer wg.Done()
    http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintf(w, "Hello, World!")
    })
    fmt.Println("HTTP server started")
    http.ListenAndServe(":8080", nil)
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go startHTTPServer(&wg)

    // 模拟一些运行时间
    time.Sleep(5 * time.Second)

    // 关闭服务器逻辑(这里简单示意)
    fmt.Println("Shutting down server")
    // 这里应该有实际的关闭 HTTP 服务器的代码

    wg.Wait()
    fmt.Println("Server is shut down")
}

注意事项

  1. 计数器的增减必须平衡:如果在 Add 中增加的数量与 Done 中减少的数量不相等,可能会导致 Wait 永远阻塞。例如,如果忘记调用 Done 或者多次调用 Add 而没有相应的 Done,都会出现问题。
  2. 避免重复使用:一旦 Wait 返回,WaitGroup 的状态就变为未使用状态。如果再次使用 WaitGroup 而没有重新初始化(通过 Add 重新设置计数器),可能会导致未定义的行为。例如,在上面 Wait 方法的实现中,如果在 Wait 返回后再次调用 Wait 而没有重新 Add,就会触发 panic
  3. 并发安全:虽然 WaitGroup 本身是并发安全的,但在使用时仍需注意相关的操作是否在并发环境下是安全的。例如,如果在多个 goroutine 中同时调用 Add 并且依赖计数器的中间状态,可能会出现竞争条件。

与其他同步原语的比较

  1. sync.Mutex 的比较sync.Mutex 主要用于保护共享资源,防止多个 goroutine 同时访问导致数据竞争。而 WaitGroup 则侧重于等待一组 goroutine 完成工作,它并不直接用于保护数据。例如,在一个多 goroutine 读写共享数据的场景中,需要使用 Mutex 来保护数据,而如果需要等待所有读写操作完成后再进行下一步处理,可以使用 WaitGroup
  2. sync.Cond 的比较sync.Cond 通常与 Mutex 结合使用,用于在满足特定条件时通知等待的 goroutine。WaitGroup 则更简单直接,只关注一组 goroutine 的完成情况,不需要设置复杂的条件。例如,在生产者 - 消费者模型中,如果消费者需要等待生产者生产一定数量的数据后再进行消费,可以使用 sync.Cond;而如果只是需要等待所有生产者完成生产任务,WaitGroup 是更好的选择。

总结与展望

WaitGroup 作为 Go 语言中常用的同步原语,为我们管理 goroutine 的并发执行提供了便利。通过深入了解其实现原理,我们可以更准确地使用它,避免常见的错误。在实际开发中,合理运用 WaitGroup 可以提高程序的并发性能和稳定性。随着 Go 语言的不断发展,未来可能会对 WaitGroup 进行优化或扩展,以满足更复杂的并发需求。例如,可能会增加一些新的方法来更灵活地控制等待组的行为,或者在性能上进行进一步的提升,以适应大规模并发场景。同时,在使用 WaitGroup 时,我们也应该结合其他同步原语和并发编程技巧,构建出高效、可靠的并发程序。