MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go 语言 WaitGroup 的实现原理与并发控制

2023-08-133.7k 阅读

Go 语言并发编程基础

在深入探讨 WaitGroup 之前,我们先来回顾一下 Go 语言并发编程的基础概念。Go 语言从诞生之初就对并发编程提供了原生且强大的支持。通过 goroutinechannel 这两个核心机制,Go 语言使得编写高并发程序变得相对简洁和高效。

goroutine

goroutine 是 Go 语言中实现并发的轻量级线程。与传统操作系统线程相比,goroutine 的创建和销毁开销极小。在 Go 语言中,只需在函数调用前加上 go 关键字,就可以创建一个新的 goroutine。例如:

package main

import (
    "fmt"
    "time"
)

func say(s string) {
    for i := 0; i < 5; i++ {
        time.Sleep(100 * time.Millisecond)
        fmt.Println(s)
    }
}

func main() {
    go say("world")
    say("hello")
}

在上述代码中,go say("world") 创建了一个新的 goroutine 来执行 say("world") 函数,而 say("hello") 则在主 goroutine 中执行。这两个 goroutine 是并发执行的。

channel

channel 是 Go 语言中用于 goroutine 之间通信和同步的机制。它可以看作是一个类型化的管道,数据可以从一端发送,从另一端接收。通过 channel,可以避免共享内存带来的竞态条件问题。例如:

package main

import (
    "fmt"
)

func sum(s []int, c chan int) {
    sum := 0
    for _, v := range s {
        sum += v
    }
    c <- sum
}

func main() {
    s := []int{7, 2, 8, -9, 4, 0}

    c := make(chan int)
    go sum(s[:len(s)/2], c)
    go sum(s[len(s)/2:], c)
    x, y := <-c, <-c

    fmt.Println(x, y, x+y)
}

在这段代码中,我们创建了一个 channel c,并启动了两个 goroutine 分别计算切片 s 的前半部分和后半部分的和。然后通过从 channel 接收数据,获取这两个计算结果并最终输出总和。

WaitGroup 概述

WaitGroup 是 Go 语言标准库 sync 包中的一个类型,用于实现 goroutine 的同步。它允许一个 goroutine 等待一组 goroutine 完成各自的任务。WaitGroup 内部维护了一个计数器,通过 Add 方法增加计数器的值,通过 Done 方法减少计数器的值,通过 Wait 方法阻塞当前 goroutine,直到计数器的值变为 0。

WaitGroup 的基本使用

以下是一个简单的示例,展示了 WaitGroup 的基本用法:

package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 3; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
    fmt.Println("All workers done")
}

在这个例子中,我们创建了一个 WaitGroup 实例 wg。在循环中,每次启动一个新的 goroutine 时,调用 wg.Add(1) 增加计数器的值。在 worker 函数中,通过 defer wg.Done() 来减少计数器的值。最后,在主 goroutine 中调用 wg.Wait(),这会阻塞主 goroutine,直到所有的 worker goroutine 都调用了 wg.Done(),即计数器的值变为 0,此时主 goroutine 继续执行并输出 "All workers done"。

WaitGroup 的实现原理

WaitGroup 的实现基于 Go 语言的 sync 包中的一些底层同步机制,主要包括原子操作和信号量。

数据结构

WaitGroup 的核心数据结构定义在 src/sync/waitgroup.go 中:

// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls Add to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 386 and arm
    // do not have 64-bit hardware alignment for 64-bit words.
    // For this reason we allocate 12 bytes and then use the aligned 8 bytes in them as state.
    state1 [3]uint32
}

可以看到,WaitGroup 结构体中包含一个 noCopy 字段,它用于防止 WaitGroup 被复制(因为复制 WaitGroup 可能会导致同步状态不一致)。另外,state1 字段是一个包含 3 个 uint32 的数组,其中高 32 位用于表示计数器的值,低 32 位用于表示等待的 goroutine 的数量。

Add 方法

Add 方法用于增加 WaitGroup 的计数器值。其实现如下:

// Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter becomes zero, all goroutines blocked on Wait are released.
// If the counter goes negative, Add panics.
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if delta > 0 && v == int32(delta) {
        // The first increment after the counter was zero must not wake
        // any goroutines. This would introduce a race with Wait.
        return
    }
    if w != 0 {
        runtime_Semrelease(semap, false, 0)
    }
}

Add 方法中,首先通过 wg.state() 获取 statep(指向状态值的指针)和 semap(指向信号量的指针)。然后使用原子操作 atomic.AddUint64 增加计数器的值(通过将 delta 左移 32 位后与当前状态值相加)。接着检查计数器是否为负数,如果是则 panic。如果增加后的计数器值等于 delta 且之前计数器为 0(表示这是计数器从 0 变为非 0 的首次增加),则直接返回,因为此时不应该唤醒任何等待的 goroutine。否则,如果有等待的 goroutine(即 w != 0),则调用 runtime_Semrelease 释放信号量,唤醒等待的 goroutine

Done 方法

Done 方法实际上是 Add(-1) 的便捷调用,其实现如下:

// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

这样设计使得在 goroutine 中调用 wg.Done() 更加方便,而不需要手动传入 -1 调用 Add 方法。

Wait 方法

Wait 方法用于阻塞当前 goroutine,直到 WaitGroup 的计数器变为 0。其实现如下:

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            // Counter is 0, no need to wait.
            return
        }
        // Increment waiters count.
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

Wait 方法中,首先获取 statepsemap。然后在一个无限循环中,通过 atomic.LoadUint64 加载当前状态值,并提取计数器的值 v。如果计数器为 0,则直接返回,因为所有任务已经完成。否则,使用 atomic.CompareAndSwapUint64 尝试增加等待者的数量(通过将状态值加 1)。如果成功增加等待者数量,则调用 runtime_Semacquire 获取信号量,进入等待状态。当被唤醒后,再次检查状态值,如果不为 0,则 panic,表示 WaitGroup 在之前的 Wait 调用返回之前被重用了。

WaitGroup 在实际场景中的应用

批量任务处理

在很多实际应用中,我们需要并发地执行一组任务,并在所有任务完成后进行下一步操作。例如,在一个爬虫程序中,可能需要并发地抓取多个网页的数据,然后对这些数据进行汇总分析。

package main

import (
    "fmt"
    "io/ioutil"
    "net/http"
    "sync"
)

func fetch(url string, wg *sync.WaitGroup) {
    defer wg.Done()
    resp, err := http.Get(url)
    if err != nil {
        fmt.Printf("Error fetching %s: %v\n", url, err)
        return
    }
    defer resp.Body.Close()
    _, err = ioutil.ReadAll(resp.Body)
    if err != nil {
        fmt.Printf("Error reading %s: %v\n", url, err)
        return
    }
    fmt.Printf("Fetched %s successfully\n", url)
}

func main() {
    urls := []string{
        "https://www.example.com",
        "https://www.google.com",
        "https://www.github.com",
    }
    var wg sync.WaitGroup
    for _, url := range urls {
        wg.Add(1)
        go fetch(url, &wg)
    }
    wg.Wait()
    fmt.Println("All fetches completed")
}

在这个例子中,我们定义了 fetch 函数用于抓取指定 URL 的内容。在 main 函数中,遍历 URL 列表,为每个 URL 创建一个 goroutine 来执行 fetch 操作,并通过 WaitGroup 来等待所有抓取任务完成。

并行计算

在科学计算或数据分析领域,经常需要对大量数据进行并行计算。例如,计算一个大型矩阵的乘法,我们可以将矩阵划分成多个子矩阵,并发地计算这些子矩阵的乘积,最后合并结果。

package main

import (
    "fmt"
    "sync"
)

func matrixMultiplySub(a, b [][]int, startRow, endRow, startCol, endCol int, result [][]int, wg *sync.WaitGroup) {
    defer wg.Done()
    for i := startRow; i < endRow; i++ {
        for j := startCol; j < endCol; j++ {
            for k := 0; k < len(b); k++ {
                result[i][j] += a[i][k] * b[k][j]
            }
        }
    }
}

func matrixMultiply(a, b [][]int) [][]int {
    rowsA := len(a)
    colsA := len(a[0])
    colsB := len(b[0])
    result := make([][]int, rowsA)
    for i := range result {
        result[i] = make([]int, colsB)
    }

    var wg sync.WaitGroup
    numWorkers := 4
    rowStep := rowsA / numWorkers
    for i := 0; i < numWorkers; i++ {
        startRow := i * rowStep
        endRow := (i + 1) * rowStep
        if i == numWorkers-1 {
            endRow = rowsA
        }
        wg.Add(1)
        go matrixMultiplySub(a, b, startRow, endRow, 0, colsB, result, &wg)
    }
    wg.Wait()
    return result
}

func main() {
    a := [][]int{
        {1, 2},
        {3, 4},
    }
    b := [][]int{
        {5, 6},
        {7, 8},
    }
    result := matrixMultiply(a, b)
    for _, row := range result {
        fmt.Println(row)
    }
}

在这个矩阵乘法的例子中,我们将矩阵 a 按行划分成多个部分,每个部分由一个 goroutine 负责计算与矩阵 b 的乘积,并通过 WaitGroup 确保所有计算完成后返回最终结果。

WaitGroup 使用的注意事项

避免重复添加

重复调用 Add 方法且参数为正数,可能会导致 Wait 方法永远阻塞。例如:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        time.Sleep(time.Second)
        wg.Done()
    }()
    wg.Add(1) // 重复添加,导致 Wait 永远阻塞
    wg.Wait()
    fmt.Println("Should not reach here")
}

在这个例子中,主 goroutine 先调用 wg.Add(1) 并启动一个 goroutine,在这个 goroutine 中调用 wg.Done() 后,主 goroutine 又额外调用了一次 wg.Add(1),这使得计数器永远不会变为 0,wg.Wait() 会一直阻塞。

避免提前释放

如果在所有需要调用 Donegoroutine 完成之前调用 Wait,可能会导致部分任务未完成就继续执行后续代码。例如:

package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, wg *sync.WaitGroup) {
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)
    wg.Done()
}

func main() {
    var wg sync.WaitGroup
    wg.Add(3)
    go worker(1, &wg)
    go worker(2, &wg)
    wg.Wait() // 过早调用 Wait,可能有 goroutine 还未开始执行
    go worker(3, &wg)
    fmt.Println("All workers done?")
}

在这个例子中,主 goroutine 启动了两个 goroutine 后就调用了 wg.Wait(),此时第三个 goroutine 还未启动,导致输出 "All workers done?" 时,第三个 goroutine 可能还未执行完毕。

避免重用

WaitGroup 设计为一次性使用,如果在 Wait 方法返回之前重用 WaitGroup,可能会导致未定义行为。例如:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        time.Sleep(time.Second)
        wg.Done()
    }()
    wg.Wait()
    wg.Add(1) // 重用 WaitGroup,可能导致未定义行为
    go func() {
        time.Sleep(time.Second)
        wg.Done()
    }()
    wg.Wait()
    fmt.Println("Finished")
}

在这个例子中,第一次 Wait 完成后,再次使用 wg.Add(1)wg.Wait(),这违反了 WaitGroup 的设计原则,可能会导致不可预测的结果。

与其他同步机制的比较

与 channel 比较

channel 主要用于 goroutine 之间的通信,通过发送和接收数据来实现同步。而 WaitGroup 更侧重于等待一组 goroutine 完成任务,不涉及数据的传递。例如,在一个生产者 - 消费者模型中,channel 可以用于生产者向消费者传递数据,而 WaitGroup 可以用于等待所有生产者完成生产任务。

与 Mutex 比较

Mutex(互斥锁)主要用于保护共享资源,防止多个 goroutine 同时访问导致竞态条件。WaitGroup 并不直接用于保护资源,而是用于协调 goroutine 的执行顺序,确保一组 goroutine 完成后再进行下一步操作。例如,在一个多 goroutine 访问共享数据库的场景中,Mutex 用于保护数据库连接,而 WaitGroup 可以用于等待所有数据库操作完成。

总结

WaitGroup 是 Go 语言并发编程中一个非常实用的工具,通过对计数器的原子操作和信号量机制,实现了对一组 goroutine 的同步控制。在实际应用中,无论是批量任务处理还是并行计算等场景,WaitGroup 都能发挥重要作用。但在使用过程中,需要注意避免重复添加、提前释放和重用等常见问题,以确保程序的正确性和稳定性。同时,理解 WaitGroup 与其他同步机制如 channelMutex 的区别,有助于我们在不同的并发场景中选择合适的工具,编写出高效、健壮的并发程序。

通过深入理解 WaitGroup 的实现原理和使用方法,我们能够更好地利用 Go 语言的并发特性,提升程序的性能和响应能力,为构建大型、复杂的分布式系统奠定坚实的基础。无论是在网络编程、数据分析还是云计算等领域,WaitGroup 都将是我们并发编程工具箱中的得力助手。