MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go语言WaitGroup的核心概念

2021-07-196.1k 阅读

Go语言并发编程基础

在深入探讨Go语言的WaitGroup之前,我们先来回顾一下Go语言并发编程的一些基础概念。Go语言从诞生之初就对并发编程提供了原生且高效的支持,这主要体现在其独特的goroutinechannel机制上。

goroutine

goroutine是Go语言中实现并发的核心组件,它类似于线程,但与传统线程有很大的区别。传统线程一般由操作系统内核管理,创建和销毁的开销较大,而goroutine是由Go运行时(runtime)管理的轻量级线程,其创建和销毁的开销极小。一个程序中可以轻松创建数以万计的goroutine,这使得Go语言在处理高并发场景时具有极高的效率。

下面是一个简单的goroutine示例代码:

package main

import (
    "fmt"
    "time"
)

func say(s string) {
    for i := 0; i < 5; i++ {
        time.Sleep(100 * time.Millisecond)
        fmt.Println(s)
    }
}

func main() {
    go say("world")
    say("hello")
}

在上述代码中,我们在main函数中启动了一个新的goroutine来执行say("world")函数,同时main函数本身也在一个goroutine中执行say("hello")。这两个goroutine并发执行,输出结果可能会交错。

channel

channel是Go语言中用于goroutine之间通信和同步的重要工具。它可以看作是一个管道,数据可以从一端发送进去,从另一端接收出来。通过channel,不同的goroutine之间可以安全地传递数据,避免了共享内存带来的并发问题。

下面是一个简单的channel示例代码:

package main

import (
    "fmt"
)

func sum(s []int, c chan int) {
    sum := 0
    for _, v := range s {
        sum += v
    }
    c <- sum
}

func main() {
    s := []int{7, 2, 8, -9, 4, 0}

    c := make(chan int)
    go sum(s[:len(s)/2], c)
    go sum(s[len(s)/2:], c)
    x, y := <-c, <-c

    fmt.Println(x, y, x+y)
}

在这个示例中,我们创建了一个channel c,并启动了两个goroutine分别计算切片s的前半部分和后半部分的和。计算结果通过channel c返回,最后在main函数中接收并输出结果。

WaitGroup的基本概念

在实际的并发编程中,我们常常需要等待一组goroutine全部完成后再继续执行后续的逻辑。这时候,WaitGroup就派上用场了。WaitGroup是Go标准库sync包中的一个类型,它提供了一种简单的方式来同步多个goroutine的执行。

原理概述

WaitGroup内部维护了一个计数器,这个计数器的值表示需要等待完成的goroutine的数量。当我们启动一个新的goroutine时,可以通过WaitGroupAdd方法增加计数器的值;当一个goroutine完成任务后,通过WaitGroupDone方法减少计数器的值;而主goroutine(或者其他需要等待的goroutine)可以通过WaitGroupWait方法阻塞,直到计数器的值变为0,即所有需要等待的goroutine都已完成。

基本使用方法

WaitGroup的使用主要涉及三个方法:AddDoneWait

  1. Add方法
    • 该方法用于增加WaitGroup内部计数器的值。它接受一个整数参数delta,通常情况下delta为1,表示新增一个需要等待的goroutine。如果delta为负数,则会减少计数器的值,但这种情况一般很少使用,因为Done方法专门用于安全地减少计数器的值。
  2. Done方法
    • Done方法实际上是Add(-1)的快捷方式,用于表示一个goroutine已经完成任务,将WaitGroup内部计数器的值减1。在goroutine的任务完成时,调用这个方法是非常重要的,否则WaitGroup的计数器不会归零,Wait方法将一直阻塞。
  3. Wait方法
    • Wait方法会阻塞调用它的goroutine,直到WaitGroup内部计数器的值变为0。也就是说,只有当所有通过Add方法添加的goroutine都调用了Done方法后,Wait方法才会返回,调用Waitgoroutine才能继续执行后续的代码。

WaitGroup的代码示例

下面通过几个具体的代码示例来深入理解WaitGroup的使用。

简单示例:等待多个goroutine完成

package main

import (
    "fmt"
    "sync"
    "time"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d starting\n", id)
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 3; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    fmt.Println("Waiting for all workers to finish...")
    wg.Wait()
    fmt.Println("All workers have finished.")
}

在上述代码中,我们创建了一个WaitGroup wg。然后通过循环启动了3个goroutine,每个goroutine在启动前调用wg.Add(1)增加计数器的值。在worker函数内部,使用defer wg.Done()来确保函数结束时计数器减1。main函数中调用wg.Wait()等待所有goroutine完成,只有当所有worker goroutine都调用了wg.Done()后,main函数才会继续执行并输出"All workers have finished."。

复杂示例:任务分组等待

有时候,我们可能需要对goroutine进行分组管理,不同组的goroutine完成任务的时间和逻辑可能不同,但我们仍然希望能够分别等待不同组的完成。这时候可以通过创建多个WaitGroup来实现。

package main

import (
    "fmt"
    "sync"
    "time"
)

func group1Worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Group 1 - Worker %d starting\n", id)
    time.Sleep(2 * time.Second)
    fmt.Printf("Group 1 - Worker %d done\n", id)
}

func group2Worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Group 2 - Worker %d starting\n", id)
    time.Sleep(1 * time.Second)
    fmt.Printf("Group 2 - Worker %d done\n", id)
}

func main() {
    var wg1 sync.WaitGroup
    var wg2 sync.WaitGroup

    for i := 1; i <= 2; i++ {
        wg1.Add(1)
        go group1Worker(i, &wg1)
    }

    for i := 1; i <= 3; i++ {
        wg2.Add(1)
        go group2Worker(i, &wg2)
    }

    fmt.Println("Waiting for Group 1 to finish...")
    wg1.Wait()
    fmt.Println("Group 1 has finished.")

    fmt.Println("Waiting for Group 2 to finish...")
    wg2.Wait()
    fmt.Println("Group 2 has finished.")

    fmt.Println("All groups have finished.")
}

在这个示例中,我们创建了两个WaitGroupwg1用于管理第一组goroutinewg2用于管理第二组goroutine。第一组goroutine模拟了执行时间较长的任务,第二组goroutine模拟了执行时间较短的任务。main函数先等待第一组goroutine完成,再等待第二组goroutine完成,最后输出所有组都已完成的信息。

WaitGroup的实现原理

深入了解WaitGroup的实现原理有助于我们更好地使用它,并且在遇到问题时能够更准确地排查和解决。

数据结构

在Go的标准库源码中,WaitGroup的定义如下:

// src/sync/waitgroup.go
type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

这里的noCopy是一个用于防止WaitGroup被复制的结构体,它主要是为了在编译时检测是否有对WaitGroup进行复制的操作,因为WaitGroup在被复制后可能会导致错误的同步行为。

state1这个数组则是WaitGroup实现的关键部分。state1实际上存储了两个重要的信息:计数器的值和等待的goroutine的数量。在64位系统上,state1的前两个uint32组成一个uint64,高32位存储等待的goroutine的数量,低32位存储计数器的值;第三个uint32用于信号通知。

Add方法实现

Add方法的实现代码如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
        return
    }
    for ; w != 0; w-- {
        runtime_Semrelease(semap)
    }
}

Add方法中,首先通过wg.state()获取state1数组对应的指针statep和信号量指针semap。然后使用atomic.AddUint64原子操作增加计数器的值(通过左移32位来更新计数器部分)。接着检查计数器的值是否为负数,如果是则抛出异常,因为计数器不应该为负。同时还检查是否存在AddWait并发调用导致的错误使用情况。如果计数器不为0或者等待的goroutine数量为0,则直接返回。否则,通过释放信号量来唤醒等待的goroutine

Done方法实现

Done方法实际上就是Add(-1)的快捷方式,其实现如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

这样就使得goroutine在完成任务时可以方便地调用Done方法来减少计数器的值。

Wait方法实现

Wait方法的实现代码如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            return
        }
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            if *statep&1 != 0 {
                runtime_Semrelease(semap)
            }
            return
        }
    }
}

Wait方法中,同样先获取state1数组指针statep和信号量指针semap。然后通过一个无限循环来检查计数器的值。如果计数器的值为0,说明所有goroutine都已完成,直接返回。否则,使用atomic.CompareAndSwapUint64原子操作尝试增加等待的goroutine数量(通过在state的高32位加1)。如果成功增加等待的goroutine数量,则调用runtime_Semacquire获取信号量,进入等待状态。当被唤醒后,如果state的最低位为1(表示有其他goroutine在唤醒等待的goroutine),则再次释放信号量,最后返回。

WaitGroup使用中的常见问题与注意事项

在使用WaitGroup时,有一些常见的问题和注意事项需要我们关注,以避免出现难以排查的并发错误。

计数器操作不当

  1. 负数计数器:如前文所述,Add方法不应该将计数器设置为负数,否则会导致panic。在实际使用中,要确保AddDone方法的调用次数匹配。如果在某个goroutine中忘记调用Done方法,或者错误地多次调用Add方法,都可能导致计数器出现异常情况。
  2. 并发调用问题Add方法不应该与Wait方法并发调用,否则可能会导致程序出现未定义行为。虽然Go运行时会对这种错误使用情况进行检测并抛出panic,但在编写代码时我们还是要尽量避免这种情况的发生。例如,不要在Wait方法执行期间动态地添加新的需要等待的goroutine

资源泄漏

如果在goroutine中使用了WaitGroup,但由于某些原因(如goroutine内部发生panic)导致Done方法没有被调用,那么WaitGroup的计数器将永远不会归零,调用Wait方法的goroutine将一直阻塞,从而导致资源泄漏。为了避免这种情况,可以在goroutine中使用defer语句来确保Done方法一定会被调用,即使goroutine内部发生了panic。例如:

func worker(wg *sync.WaitGroup) {
    defer wg.Done()
    // 这里是具体的工作逻辑,即使发生panic,wg.Done()也会被调用
}

嵌套使用

在复杂的并发场景中,可能会出现WaitGroup嵌套使用的情况。例如,一个goroutine内部又启动了多个子goroutine,并使用WaitGroup来等待这些子goroutine完成。在这种情况下,要特别注意AddDone方法的调用顺序和作用范围,确保计数器的增减操作在正确的位置执行,避免出现死锁或者逻辑错误。

WaitGroup与其他同步机制的比较

在Go语言的并发编程中,除了WaitGroup,还有其他一些同步机制,如mutex(互斥锁)、cond(条件变量)等。了解WaitGroup与这些同步机制的区别和适用场景,有助于我们在实际编程中选择最合适的工具。

WaitGroup与Mutex

  1. 功能侧重
    • Mutex主要用于保护共享资源,防止多个goroutine同时访问,从而避免数据竞争问题。它通过锁定和解锁操作来实现对共享资源的独占访问。
    • WaitGroup主要用于同步goroutine的执行,等待一组goroutine全部完成后再继续执行后续逻辑,并不涉及对共享资源的保护。
  2. 使用场景
    • 当需要确保某个资源在同一时间只能被一个goroutine访问时,应该使用Mutex。例如,对共享的数据库连接池进行操作时,为了避免多个goroutine同时修改连接池的状态,就需要使用Mutex来保护。
    • 当需要等待一组goroutine完成任务后再进行下一步操作时,如在并行计算任务完成后汇总结果,就适合使用WaitGroup

WaitGroup与Cond

  1. 功能侧重
    • Cond(条件变量)用于在共享资源的状态发生变化时,通知等待该条件的goroutine。它通常与Mutex配合使用,通过CondWait方法等待条件满足,通过SignalBroadcast方法通知等待的goroutine
    • WaitGroup则是简单地等待一组goroutine完成,不涉及共享资源状态变化的复杂通知机制。
  2. 使用场景
    • 当需要根据某个条件的变化来唤醒等待的goroutine时,比如在生产者 - 消费者模型中,当队列中有数据时通知消费者goroutine,就需要使用Cond
    • 而在只需要等待一组goroutine全部执行完毕的场景下,WaitGroup更加简洁直接,如并行下载多个文件后统一处理。

WaitGroup在实际项目中的应用场景

WaitGroup在实际项目中有许多应用场景,下面列举一些常见的场景。

并行任务处理

在需要并行处理多个任务的场景中,WaitGroup非常有用。例如,在一个数据处理系统中,需要同时对多个文件进行读取和处理。可以为每个文件处理任务启动一个goroutine,并使用WaitGroup来等待所有文件处理完成后再进行下一步的汇总或分析操作。

package main

import (
    "fmt"
    "io/ioutil"
    "path/filepath"
    "sync"
)

func processFile(filePath string, wg *sync.WaitGroup) {
    defer wg.Done()
    data, err := ioutil.ReadFile(filePath)
    if err != nil {
        fmt.Printf("Error reading file %s: %v\n", filePath, err)
        return
    }
    // 这里进行文件数据的具体处理
    fmt.Printf("Processed file %s: %d bytes\n", filePath, len(data))
}

func main() {
    var wg sync.WaitGroup
    files, err := filepath.Glob("*.txt")
    if err != nil {
        fmt.Printf("Error getting files: %v\n", err)
        return
    }
    for _, file := range files {
        wg.Add(1)
        go processFile(file, &wg)
    }
    wg.Wait()
    fmt.Println("All files processed.")
}

在上述代码中,我们通过filepath.Glob获取所有.txt文件,然后为每个文件启动一个goroutine进行处理。WaitGroup用于等待所有文件处理goroutine完成后输出所有文件已处理的信息。

服务启动与关闭

在开发服务器应用时,WaitGroup可以用于管理服务启动和关闭过程中的goroutine。例如,在启动服务器时,可能需要初始化多个组件,每个组件的初始化可以在一个单独的goroutine中进行。在关闭服务器时,也需要等待所有相关的goroutine完成清理工作。

package main

import (
    "fmt"
    "net/http"
    "sync"
    "time"
)

func startComponent(componentName string, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Starting %s...\n", componentName)
    // 模拟组件启动过程
    time.Sleep(2 * time.Second)
    fmt.Printf("%s started.\n", componentName)
}

func stopComponent(componentName string, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Stopping %s...\n", componentName)
    // 模拟组件关闭过程
    time.Sleep(1 * time.Second)
    fmt.Printf("%s stopped.\n", componentName)
}

func main() {
    var startWG sync.WaitGroup
    var stopWG sync.WaitGroup

    components := []string{"Database", "Cache", "API Server"}
    for _, component := range components {
        startWG.Add(1)
        go startComponent(component, &startWG)
    }
    startWG.Wait()
    fmt.Println("All components started. Server is running.")

    // 模拟服务器运行一段时间
    time.Sleep(5 * time.Second)

    for _, component := range components {
        stopWG.Add(1)
        go stopComponent(component, &stopWG)
    }
    stopWG.Wait()
    fmt.Println("All components stopped. Server is shut down.")
}

在这个示例中,我们使用WaitGroup来管理服务器组件的启动和关闭过程。在启动阶段,等待所有组件启动完成后才表示服务器开始运行;在关闭阶段,等待所有组件关闭完成后才表示服务器完全关闭。

测试并发代码

在编写并发代码的测试时,WaitGroup可以帮助我们确保所有并发操作都已完成,从而得到准确的测试结果。例如,在测试一个并发安全的计数器时,可以启动多个goroutine同时对计数器进行操作,然后使用WaitGroup等待所有操作完成后检查计数器的值是否正确。

package main

import (
    "fmt"
    "sync"
    "testing"
)

type Counter struct {
    value int
    mutex sync.Mutex
}

func (c *Counter) Increment() {
    c.mutex.Lock()
    c.value++
    c.mutex.Unlock()
}

func (c *Counter) GetValue() int {
    c.mutex.Lock()
    defer c.mutex.Unlock()
    return c.value
}

func TestConcurrentCounter(t *testing.T) {
    var wg sync.WaitGroup
    counter := Counter{}
    numGoroutines := 100
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
    }
    wg.Wait()
    if counter.GetValue() != numGoroutines {
        t.Errorf("Expected value %d, got %d", numGoroutines, counter.GetValue())
    } else {
        fmt.Println("Concurrent counter test passed.")
    }
}

在上述测试代码中,我们启动了100个goroutine并发调用counter.Increment()方法,通过WaitGroup等待所有goroutine完成操作后,检查计数器的值是否与预期相符,从而验证并发计数器的正确性。

通过以上对WaitGroup的核心概念、使用方法、实现原理、常见问题、与其他同步机制的比较以及实际应用场景的详细介绍,相信你对Go语言中的WaitGroup有了更深入全面的理解,能够在实际的并发编程项目中灵活准确地使用它来实现高效的同步操作。