MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go语言WaitGroup并发控制详解

2023-08-175.0k 阅读

1. Go语言并发编程基础

在深入了解 WaitGroup 之前,我们先来回顾一下Go语言并发编程的一些基础知识。Go语言的并发模型基于 goroutinechannel

1.1 goroutine

goroutine 是Go语言中轻量级的线程实现。与传统的线程不同,goroutine 的创建和销毁开销极小。一个程序可以轻松创建数以万计的 goroutine。通过 go 关键字来启动一个 goroutine,例如:

package main

import (
    "fmt"
)

func printHello() {
    fmt.Println("Hello, Goroutine!")
}

func main() {
    go printHello()
    fmt.Println("Main function")
}

在上述代码中,go printHello() 启动了一个新的 goroutine 来执行 printHello 函数。主 goroutine 并不会等待新启动的 goroutine 完成,而是继续执行后续的代码,即打印 "Main function"。这里会发现,程序可能在新 goroutine 执行 printHello 函数之前就结束了,这是因为主 goroutine 结束后整个程序就结束了。

1.2 channel

channel 是Go语言中用于在 goroutine 之间进行通信的管道。它可以确保数据在不同 goroutine 之间安全传递,避免了共享内存带来的竞态条件问题。创建一个 channel 如下:

ch := make(chan int)

这创建了一个可以传递 int 类型数据的 channel。向 channel 发送数据使用 <- 操作符:

ch <- 10

channel 接收数据也使用 <- 操作符:

data := <-ch

一个完整的示例如下:

package main

import (
    "fmt"
)

func sendData(ch chan int) {
    ch <- 42
}

func main() {
    ch := make(chan int)
    go sendData(ch)
    data := <-ch
    fmt.Println("Received data:", data)
}

在这个例子中,sendData 函数在一个新的 goroutine 中向 channel 发送数据,主 goroutinechannel 接收数据并打印。

2. WaitGroup 简介

WaitGroup 是Go标准库 sync 包中的一个类型,用于实现 goroutine 的同步。它可以让一个 goroutine 等待一组 goroutine 全部完成。WaitGroup 内部维护一个计数器,通过 Add 方法增加计数器的值,通过 Done 方法减少计数器的值,Wait 方法会阻塞当前 goroutine,直到计数器的值变为0。

3. WaitGroup 的基本用法

3.1 简单示例

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    // 模拟一些工作
    for i := 0; i < 3; i++ {
        fmt.Printf("Worker %d working: %d\n", id, i)
    }
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 3

    for i := 1; i <= numWorkers; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }

    wg.Wait()
    fmt.Println("All workers are done")
}

在上述代码中:

  1. 首先在 main 函数中创建了一个 WaitGroup 实例 wg
  2. 通过循环启动了 numWorkersgoroutine,每个 goroutine 调用 worker 函数。在启动每个 goroutine 之前,调用 wg.Add(1) 增加计数器的值。
  3. worker 函数中,使用 defer wg.Done() 来确保函数结束时计数器减1。defer 关键字会在函数返回之前执行其后面的语句,这样无论 worker 函数以何种方式结束,计数器都会正确减少。
  4. 最后在 main 函数中调用 wg.Wait(),主 goroutine 会阻塞在这里,直到所有 worker goroutine 都调用了 wg.Done(),计数器变为0,此时主 goroutine 继续执行后续代码,打印 "All workers are done"。

3.2 避免重复添加计数器

在使用 WaitGroup 时,要特别注意避免重复添加计数器。如果多次调用 Add 方法而没有相应数量的 Done 调用,Wait 方法可能会永远阻塞。例如:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine started")
    }()
    wg.Add(1) // 重复添加,可能导致问题
    wg.Wait()
}

在这个例子中,第二次调用 wg.Add(1) 是多余的,因为只启动了一个 goroutine。这可能会导致 wg.Wait() 永远阻塞,因为没有额外的 Done 调用来减少这个多余添加的计数器。

4. WaitGroup 实现原理

WaitGroup 的实现基于Go语言的运行时调度器和同步原语。它内部使用了一个计数器和一个信号量。

4.1 数据结构

在Go语言的标准库源码(src/sync/waitgroup.go)中,WaitGroup 的定义如下:

type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

noCopy 是一个标记结构体,用于防止 WaitGroup 被复制,因为复制 WaitGroup 可能会导致未定义行为。state1 数组用于存储计数器的值和信号量的状态。具体来说,前两个 uint32 用于计数器和信号量,第三个 uint32 用于存储 WaitGroup 的一些状态信息。

4.2 Add 方法

Add 方法用于增加计数器的值。其实现大致如下:

func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    for ; state>>32 != uint64(v); state = atomic.AddUint64(statep, 0) {
    }
    if v > 0 || delta < 0 {
        return
    }
    for ; state != 0; state = atomic.AddUint64(statep, 0) {
        runtime_Semrelease(semap, false, 0)
    }
}

这里使用了原子操作 atomic.AddUint64 来安全地增加计数器的值。如果增加后计数器变为负数,会触发 panic,因为计数器不应该为负。如果计数器变为0,会释放信号量,通知等待的 goroutine

4.3 Done 方法

Done 方法实际上是 Add(-1) 的快捷方式:

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

它通过调用 Add(-1) 来减少计数器的值。同样,这里也使用原子操作确保计数器的安全修改。

4.4 Wait 方法

Wait 方法用于阻塞当前 goroutine,直到计数器变为0:

func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            return
        }
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            atomic.AddUint64(statep, -1)
        }
    }
}

Wait 方法首先加载当前计数器的值,如果计数器为0,直接返回。否则,通过 runtime_Semacquire 尝试获取信号量,获取成功后减少计数器的值。如果获取信号量失败,会再次尝试加载计数器的值并获取信号量,直到计数器变为0或成功获取信号量。

5. WaitGroup 在复杂场景中的应用

5.1 分组等待

有时候我们需要将 goroutine 分成不同的组,并分别等待每组完成。可以通过多个 WaitGroup 实例来实现。例如,假设有两组任务,一组任务处理数据的读取,另一组任务处理数据的计算:

package main

import (
    "fmt"
    "sync"
)

func readData(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Read task %d started\n", id)
    // 模拟读取数据
    fmt.Printf("Read task %d done\n", id)
}

func processData(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Process task %d started\n", id)
    // 模拟数据处理
    fmt.Printf("Process task %d done\n", id)
}

func main() {
    var readWG sync.WaitGroup
    var processWG sync.WaitGroup

    numReadTasks := 2
    numProcessTasks := 3

    for i := 1; i <= numReadTasks; i++ {
        readWG.Add(1)
        go readData(i, &readWG)
    }

    for i := 1; i <= numProcessTasks; i++ {
        processWG.Add(1)
        go processData(i, &processWG)
    }

    readWG.Wait()
    fmt.Println("All read tasks are done")

    processWG.Wait()
    fmt.Println("All process tasks are done")
}

在这个例子中,分别为读取任务和处理任务创建了 WaitGroup 实例 readWGprocessWG。通过 readWG.Wait()processWG.Wait() 分别等待两组任务完成。

5.2 嵌套使用 WaitGroup

在一些复杂的业务逻辑中,可能会出现 goroutine 内部又启动新的 goroutine 的情况,这时候可以嵌套使用 WaitGroup。例如:

package main

import (
    "fmt"
    "sync"
)

func innerTask(id int, innerWG *sync.WaitGroup) {
    defer innerWG.Done()
    fmt.Printf("Inner task %d started\n", id)
    // 模拟一些工作
    fmt.Printf("Inner task %d done\n", id)
}

func outerTask(id int, outerWG *sync.WaitGroup) {
    var innerWG sync.WaitGroup
    numInnerTasks := 2

    fmt.Printf("Outer task %d started\n", id)

    for i := 1; i <= numInnerTasks; i++ {
        innerWG.Add(1)
        go innerTask(i, &innerWG)
    }

    innerWG.Wait()
    fmt.Printf("All inner tasks in outer task %d are done\n", id)

    outerWG.Done()
}

func main() {
    var outerWG sync.WaitGroup
    numOuterTasks := 2

    for i := 1; i <= numOuterTasks; i++ {
        outerWG.Add(1)
        go outerTask(i, &outerWG)
    }

    outerWG.Wait()
    fmt.Println("All outer tasks are done")
}

在这个示例中,outerTask 函数内部启动了多个 innerTask goroutine,并使用 innerWG 来等待它们完成。outerTask 完成所有内部任务后,调用 outerWG.Done() 通知主 goroutine 它已完成。主 goroutine 使用 outerWG.Wait() 等待所有 outerTask 完成。

6. WaitGroup 与其他同步机制的结合使用

6.1 WaitGroup 与 Mutex

在并发编程中,Mutex(互斥锁)用于保护共享资源,防止多个 goroutine 同时访问。WaitGroupMutex 可以结合使用,例如在多个 goroutine 对共享资源进行读写操作时,使用 Mutex 保护资源,使用 WaitGroup 等待所有操作完成。

package main

import (
    "fmt"
    "sync"
)

var (
    sharedData int
    mu         sync.Mutex
)

func readData(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    mu.Lock()
    fmt.Printf("Reader %d reads data: %d\n", id, sharedData)
    mu.Unlock()
}

func writeData(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    mu.Lock()
    sharedData = id
    fmt.Printf("Writer %d writes data: %d\n", id, sharedData)
    mu.Unlock()
}

func main() {
    var wg sync.WaitGroup
    numReaders := 2
    numWriters := 3

    for i := 1; i <= numWriters; i++ {
        wg.Add(1)
        go writeData(i, &wg)
    }

    for i := 1; i <= numReaders; i++ {
        wg.Add(1)
        go readData(i, &wg)
    }

    wg.Wait()
    fmt.Println("All read and write operations are done")
}

在这个例子中,Mutex 用于保护 sharedData 共享变量,确保读写操作的原子性。WaitGroup 用于等待所有读写 goroutine 完成。

6.2 WaitGroup 与 channel

channelWaitGroup 也可以很好地结合。channel 用于在 goroutine 之间传递数据,WaitGroup 用于同步 goroutine。例如,在一个生产者 - 消费者模型中:

package main

import (
    "fmt"
    "sync"
)

func producer(ch chan int, wg *sync.WaitGroup) {
    defer wg.Done()
    for i := 1; i <= 5; i++ {
        ch <- i
        fmt.Printf("Produced: %d\n", i)
    }
    close(ch)
}

func consumer(ch chan int, wg *sync.WaitGroup) {
    defer wg.Done()
    for data := range ch {
        fmt.Printf("Consumed: %d\n", data)
    }
}

func main() {
    var wg sync.WaitGroup
    ch := make(chan int)

    wg.Add(1)
    go producer(ch, &wg)

    numConsumers := 2
    for i := 1; i <= numConsumers; i++ {
        wg.Add(1)
        go consumer(ch, &wg)
    }

    wg.Wait()
    fmt.Println("All production and consumption are done")
}

在这个生产者 - 消费者模型中,生产者 goroutine 通过 channel 向消费者 goroutine 发送数据。WaitGroup 用于确保生产者和所有消费者 goroutine 都完成任务。

7. WaitGroup 使用中的常见问题与解决方法

7.1 死锁问题

死锁是并发编程中常见的问题,在使用 WaitGroup 时也可能出现。例如,以下代码会导致死锁:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        wg.Wait()
        fmt.Println("Goroutine inside")
        wg.Done()
    }()
    wg.Wait()
    fmt.Println("Main function")
}

在这个例子中,新启动的 goroutine 调用 wg.Wait() 等待计数器变为0,而主 goroutine 也调用 wg.Wait() 等待,由于没有 Done 调用,计数器永远不会变为0,从而导致死锁。解决方法是确保在适当的地方调用 Done 方法,例如:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        fmt.Println("Goroutine inside")
        wg.Done()
    }()
    wg.Wait()
    fmt.Println("Main function")
}

7.2 未正确初始化计数器

如果没有正确初始化 WaitGroup 的计数器,可能会导致 Wait 方法行为异常。例如:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine working")
    }()
    wg.Wait()
    fmt.Println("All done")
}

在这个例子中,没有调用 wg.Add(1) 初始化计数器,wg.Done() 会将计数器减为负数,导致 panic。正确的做法是在启动 goroutine 之前调用 wg.Add(1)

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine working")
    }()
    wg.Wait()
    fmt.Println("All done")
}

8. 性能考量

在使用 WaitGroup 时,虽然它提供了方便的同步机制,但也需要考虑性能问题。由于 WaitGroup 内部使用了原子操作和信号量,频繁地调用 AddDoneWait 方法会带来一定的开销。

8.1 减少不必要的操作

尽量减少在循环中频繁调用 AddDone 方法。例如,如果要启动多个 goroutine,可以一次性调用 Add 方法增加计数器的值,而不是在每次启动 goroutine 时调用 Add

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d working\n", id)
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 10000
    wg.Add(numWorkers)

    for i := 1; i <= numWorkers; i++ {
        go worker(i, &wg)
    }

    wg.Wait()
    fmt.Println("All workers are done")
}

在这个例子中,一次性调用 wg.Add(numWorkers) 增加计数器的值,而不是在每次启动 worker goroutine 时调用 Add,这样可以减少原子操作的次数,提高性能。

8.2 避免过度同步

不要在不必要的地方使用 WaitGroup 进行同步。如果某些 goroutine 之间没有依赖关系,不需要等待它们全部完成再进行下一步操作,就不应该使用 WaitGroup 来同步。过度同步会降低程序的并发性能。例如,如果有一组 goroutine 分别处理不同的独立任务,并且这些任务的结果不需要汇总后再进行下一步,那么就可以让它们异步执行,而不需要使用 WaitGroup 等待。

9. 总结

WaitGroup 是Go语言并发编程中非常重要的同步工具,它提供了一种简单有效的方式来等待一组 goroutine 完成。通过深入理解其基本用法、实现原理以及在复杂场景中的应用,我们可以更好地利用它来编写高效、正确的并发程序。同时,在使用过程中要注意避免常见问题,如死锁、未正确初始化计数器等,并关注性能考量,减少不必要的同步开销,以充分发挥Go语言并发编程的优势。在实际项目中,根据具体的业务需求,合理地结合 WaitGroup 与其他同步机制,如 Mutexchannel,可以构建出健壮、高性能的并发系统。