MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go语言中的WaitGroup实现并发任务同步

2023-02-095.1k 阅读

Go语言并发编程基础

在深入探讨 WaitGroup 之前,先简要回顾一下Go语言并发编程的基础知识。Go语言以其轻量级的并发模型——goroutine,使得并发编程变得相对简单。

goroutine简介

goroutine 是Go语言中实现并发的核心概念,它类似于线程,但又有着本质的区别。与传统线程相比,goroutine非常轻量级,创建和销毁的开销极小。在Go程序中,只需在函数调用前加上 go 关键字,就可以让该函数在一个新的goroutine中运行。例如:

package main

import (
    "fmt"
    "time"
)

func printNumbers() {
    for i := 1; i <= 5; i++ {
        fmt.Println("Number:", i)
        time.Sleep(100 * time.Millisecond)
    }
}

func printLetters() {
    for i := 'a'; i <= 'e'; i++ {
        fmt.Println("Letter:", string(i))
        time.Sleep(100 * time.Millisecond)
    }
}

func main() {
    go printNumbers()
    go printLetters()

    time.Sleep(1000 * time.Millisecond)
}

在上述代码中,printNumbersprintLetters 函数分别在两个不同的goroutine中执行。main 函数启动这两个goroutine后,不会等待它们完成,而是继续执行后续代码。这里通过 time.Sleep 函数来确保 main 函数在两个goroutine执行完成之前不会退出。

并发与并行

并发(Concurrency)和并行(Parallelism)虽然在日常使用中有时会被混淆,但它们有着不同的含义。

  • 并发:指的是程序能够在同一时间段内处理多个任务,但并不意味着这些任务是同时执行的。在单核CPU系统中,操作系统通过时间片轮转等调度算法,使得多个任务看似同时执行。在Go语言中,多个goroutine可以在单个操作系统线程上多路复用,实现并发执行。
  • 并行:意味着多个任务在同一时刻真正地同时执行,这通常需要多核CPU的支持。Go语言的运行时系统(runtime)会将多个goroutine调度到多个操作系统线程上,进而充分利用多核CPU的优势,实现并行执行。

共享内存与数据竞争

在并发编程中,如果多个goroutine同时访问和修改共享数据,就可能引发数据竞争(Data Race)问题。数据竞争会导致程序出现不可预测的行为,如结果不一致、程序崩溃等。例如:

package main

import (
    "fmt"
)

var counter int

func increment() {
    counter++
}

func main() {
    for i := 0; i < 1000; i++ {
        go increment()
    }

    fmt.Println("Counter:", counter)
}

在这段代码中,多个goroutine同时调用 increment 函数对 counter 变量进行自增操作。由于没有任何同步机制,不同goroutine对 counter 的读写操作可能会相互干扰,导致最终输出的 counter 值并不是预期的1000。为了解决数据竞争问题,Go语言提供了多种同步机制,WaitGroup 就是其中之一。

WaitGroup的基本概念

WaitGroup 是Go标准库 sync 包中的一个类型,用于实现多个goroutine之间的同步。它的核心功能是阻塞一个或多个goroutine,直到一组特定的goroutine全部完成工作。

WaitGroup的结构

WaitGroup 的结构定义如下:

// src/sync/waitgroup.go
type WaitGroup struct {
    noCopy noCopy

    state1 uint64
    state2 uint32
}

虽然其内部结构看起来并不复杂,但却蕴含着丰富的功能。state1state2 这两个字段用于存储 WaitGroup 的状态信息,包括等待的goroutine数量以及已完成的goroutine数量等。noCopy 类型则是一个空结构体,用于防止 WaitGroup 被意外复制,因为 WaitGroup 在被复制后可能会导致同步逻辑出错。

WaitGroup的方法

WaitGroup 提供了三个主要方法:AddDoneWait

  • Add方法
    • 功能:用于向 WaitGroup 中添加需要等待的goroutine数量。它接受一个整数参数 delta,将 WaitGroup 内部的计数器增加 delta
    • 用法示例
var wg sync.WaitGroup
wg.Add(2) // 添加两个需要等待的goroutine
  • Done方法
    • 功能:用于标记一个goroutine已经完成工作。它实际上是 Add(-1) 的快捷方式,每调用一次 DoneWaitGroup 内部的计数器就会减1。
    • 用法示例
defer wg.Done() // 在goroutine结束前调用Done
  • Wait方法
    • 功能:阻塞当前goroutine,直到 WaitGroup 内部的计数器变为0,即所有需要等待的goroutine都调用了 Done 方法。
    • 用法示例
wg.Wait() // 等待所有goroutine完成

使用WaitGroup实现简单并发任务同步

下面通过一些具体的代码示例来展示如何使用 WaitGroup 实现简单的并发任务同步。

示例一:多个goroutine并行计算

假设我们需要计算一组数字的平方,并将结果打印出来。可以使用多个goroutine并行处理这些数字,然后通过 WaitGroup 等待所有计算完成。

package main

import (
    "fmt"
    "sync"
)

func square(wg *sync.WaitGroup, num int) {
    defer wg.Done()
    result := num * num
    fmt.Printf("Square of %d is %d\n", num, result)
}

func main() {
    var wg sync.WaitGroup
    numbers := []int{1, 2, 3, 4, 5}

    for _, num := range numbers {
        wg.Add(1)
        go square(&wg, num)
    }

    wg.Wait()
    fmt.Println("All squares calculated.")
}

在这个示例中,square 函数负责计算一个数字的平方并打印结果。在 main 函数中,遍历 numbers 切片,为每个数字启动一个新的goroutine来执行 square 函数,并通过 wg.Add(1) 添加需要等待的goroutine数量。每个 square 函数在结束前调用 defer wg.Done() 标记自己已完成工作。最后,wg.Wait() 阻塞 main 函数,直到所有计算平方的goroutine都完成。

示例二:多个goroutine读取文件内容

假设我们有多个文件,需要同时读取这些文件的内容,并在所有文件读取完成后进行一些汇总操作。

package main

import (
    "fmt"
    "io/ioutil"
    "os"
    "sync"
)

func readFile(wg *sync.WaitGroup, filePath string) {
    defer wg.Done()
    data, err := ioutil.ReadFile(filePath)
    if err != nil {
        fmt.Printf("Error reading %s: %v\n", filePath, err)
        return
    }
    fmt.Printf("Content of %s:\n%s\n", filePath, data)
}

func main() {
    var wg sync.WaitGroup
    filePaths := []string{"file1.txt", "file2.txt", "file3.txt"}

    for _, filePath := range filePaths {
        if _, err := os.Stat(filePath); os.IsNotExist(err) {
            fmt.Printf("%s does not exist.\n", filePath)
            continue
        }
        wg.Add(1)
        go readFile(&wg, filePath)
    }

    wg.Wait()
    fmt.Println("All files read.")
}

在这个示例中,readFile 函数负责读取指定文件的内容并打印。main 函数中遍历 filePaths 切片,为每个存在的文件启动一个goroutine来执行 readFile 函数,并添加到 WaitGroup 中。每个 readFile 函数完成读取后调用 wg.Done()wg.Wait() 确保在所有文件读取完成后才打印 “All files read.”。

WaitGroup的高级应用场景

除了上述简单的并发任务同步场景,WaitGroup 在一些更复杂的场景中也能发挥重要作用。

场景一:并发任务分组

有时候,我们可能需要将多个goroutine分成不同的组,并分别等待每组任务完成。例如,在一个数据分析程序中,可能有一组goroutine负责数据采集,另一组负责数据清洗,还有一组负责数据分析。可以使用多个 WaitGroup 来实现这种分组同步。

package main

import (
    "fmt"
    "sync"
    "time"
)

func dataCollection(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Data collection started.")
    time.Sleep(2 * time.Second)
    fmt.Println("Data collection finished.")
}

func dataCleaning(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Data cleaning started.")
    time.Sleep(3 * time.Second)
    fmt.Println("Data cleaning finished.")
}

func dataAnalysis(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Data analysis started.")
    time.Sleep(4 * time.Second)
    fmt.Println("Data analysis finished.")
}

func main() {
    var collectionWG sync.WaitGroup
    var cleaningWG sync.WaitGroup
    var analysisWG sync.WaitGroup

    collectionWG.Add(3)
    cleaningWG.Add(2)
    analysisWG.Add(1)

    for i := 0; i < 3; i++ {
        go dataCollection(&collectionWG)
    }

    for i := 0; i < 2; i++ {
        go dataCleaning(&cleaningWG)
    }

    go dataAnalysis(&analysisWG)

    collectionWG.Wait()
    fmt.Println("All data collection tasks completed.")

    cleaningWG.Wait()
    fmt.Println("All data cleaning tasks completed.")

    analysisWG.Wait()
    fmt.Println("All data analysis tasks completed.")

    fmt.Println("Overall data processing finished.")
}

在这个示例中,定义了三个 WaitGroup,分别用于数据采集、数据清洗和数据分析任务组。为每个任务组的goroutine添加相应的等待数量,并在每个goroutine完成后调用 Done 方法。通过依次调用每个 WaitGroupWait 方法,确保每组任务按顺序完成。

场景二:动态添加goroutine

在某些情况下,可能需要在程序运行过程中动态地添加新的goroutine,并让主goroutine等待所有动态添加的goroutine完成。例如,在一个网络爬虫程序中,可能根据初始页面的链接数量动态启动新的爬虫goroutine。

package main

import (
    "fmt"
    "sync"
)

func crawl(url string, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Crawling %s\n", url)
    // 模拟爬虫操作
}

func main() {
    var wg sync.WaitGroup
    initialURLs := []string{"http://example.com", "http://another-example.com"}

    for _, url := range initialURLs {
        wg.Add(1)
        go crawl(url, &wg)
    }

    // 假设这里根据初始页面的链接又发现了新的URL
    newURLs := []string{"http://new-url1.com", "http://new-url2.com"}
    for _, url := range newURLs {
        wg.Add(1)
        go crawl(url, &wg)
    }

    wg.Wait()
    fmt.Println("All crawling tasks completed.")
}

在这个示例中,首先为 initialURLs 中的每个URL启动一个爬虫goroutine并添加到 WaitGroup 中。之后,假设发现了新的URL,又为 newURLs 中的每个URL动态地启动新的爬虫goroutine并添加到 WaitGroup 中。wg.Wait() 确保所有爬虫goroutine都完成任务。

场景三:超时控制

在实际应用中,有时需要为等待 WaitGroup 的操作设置一个超时时间,以防止程序无限期地等待。可以结合 context.Contexttime.After 来实现这一功能。

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

func task(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Task started.")
    time.Sleep(3 * time.Second)
    fmt.Println("Task finished.")
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)

    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    go task(&wg)

    select {
    case <-ctx.Done():
        fmt.Println("Timeout waiting for task.")
    case <-time.After(100 * time.Millisecond):
        // 防止select语句阻塞,这里只是一个占位
    }

    wg.Wait()
    fmt.Println("Task completed within timeout or after timeout check.")
}

在这个示例中,使用 context.WithTimeout 创建一个带有2秒超时的 context.Context。在 select 语句中,通过监听 ctx.Done() 通道来判断是否超时。如果超时,打印 “Timeout waiting for task.”。无论是否超时,wg.Wait() 仍然会等待任务完成,最后打印 “Task completed within timeout or after timeout check.”。

WaitGroup实现原理剖析

了解 WaitGroup 的使用方法后,深入探究其实现原理有助于更好地理解和应用它。

内部状态表示

正如前面提到的,WaitGroup 的内部状态由 state1state2 两个字段表示。state1 是一个64位无符号整数,state2 是一个32位无符号整数。state1 的高32位用于存储等待的goroutine数量(即 Add 方法增加的数量),低32位用于存储已完成的goroutine数量(即 Done 方法减少的数量)。state2 则用于存储信号量,用于控制阻塞和唤醒goroutine。

Add方法实现

Add 方法的实现如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
        return
    }
    // 所有goroutine都已完成,唤醒等待的goroutine
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

Add 方法中,首先通过 wg.state() 获取 state1state2 的指针。然后使用 atomic.AddUint64 原子操作将 delta 左移32位后加到 state1 上。接着检查 delta 是否为负数,如果是则抛出恐慌,因为不允许出现负的等待计数器。同时检查是否存在并发调用 AddWait 的情况,如果存在也抛出恐慌。如果 v 大于0(即还有等待的goroutine)或者 w 为0(表示没有已完成的goroutine),则直接返回。否则,通过 runtime_Semrelease 释放信号量,唤醒等待的goroutine。

Done方法实现

Done 方法实际上是 Add(-1) 的快捷方式,其实现如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

通过调用 Add(-1)Done 方法将等待的goroutine数量减1。

Wait方法实现

Wait 方法的实现如下:

// src/sync/waitgroup.go
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        if v == 0 {
            return
        }
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            if *statep&1 != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

Wait 方法中,同样先获取 state1state2 的指针。然后在一个无限循环中,通过 atomic.LoadUint64 原子加载 state1 的值,并检查等待的goroutine数量 v 是否为0。如果为0,则表示所有goroutine都已完成,直接返回。否则,尝试使用 atomic.CompareAndSwapUint64 原子操作将 state1 的值增加1。如果操作成功,调用 runtime_Semacquire 获取信号量,进入阻塞状态。如果在获取信号量后发现 state1 的最低位为1,表示 WaitGroup 在之前的 Wait 操作返回前被重用,抛出恐慌。

WaitGroup使用注意事项

在使用 WaitGroup 时,有一些注意事项需要牢记,以避免出现难以调试的问题。

避免重复使用已完成的WaitGroup

一旦 WaitGroup 的计数器变为0,并且所有等待的goroutine都已被唤醒,就不应该再次使用该 WaitGroup 来等待新的goroutine。例如:

package main

import (
    "fmt"
    "sync"
)

func task(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Task completed.")
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go task(&wg)
    wg.Wait()

    // 不应该在这里再次使用wg
    wg.Add(1)
    go task(&wg)
    wg.Wait()
}

在上述代码中,第一次 wg.Wait() 完成后,再次使用 wg 来等待新的goroutine,这可能会导致未定义行为。正确的做法是重新创建一个新的 WaitGroup

防止Add负数或并发Add和Wait

Add 方法不应该传入负数,除非是通过 Done 方法间接调用。同时,要避免在并发环境下同时调用 AddWait 方法,这可能会导致程序出现恐慌。例如:

package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    go func() {
        wg.Add(-1) // 不应该直接调用Add(-1)
    }()
    wg.Wait()

    var wg2 sync.WaitGroup
    go func() {
        wg2.Add(1)
        wg2.Wait()
    }()
    wg2.Add(1) // 并发调用Add和Wait可能导致恐慌
}

在第一段代码中,直接调用 wg.Add(-1) 是错误的,应该使用 wg.Done()。在第二段代码中,并发调用 AddWait 可能会触发恐慌,需要通过适当的同步机制来避免这种情况。

注意WaitGroup的生命周期

确保 WaitGroup 的生命周期与需要等待的goroutine相匹配。如果 WaitGroup 在所有相关goroutine完成之前就被销毁,可能会导致goroutine泄漏。例如:

package main

import (
    "fmt"
    "sync"
    "time"
)

func longRunningTask(wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Println("Long running task started.")
    time.Sleep(5 * time.Second)
    fmt.Println("Long running task finished.")
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go longRunningTask(&wg)

    // 这里没有等待wg就退出了,可能导致goroutine泄漏
}

在上述代码中,main 函数在启动 longRunningTask 后没有调用 wg.Wait() 就直接退出,这可能会导致 longRunningTask 这个goroutine成为泄漏的goroutine。应该确保在 main 函数结束前调用 wg.Wait()

总结

WaitGroup 是Go语言并发编程中一个非常实用的工具,它能够有效地实现多个goroutine之间的同步。通过了解其基本概念、使用方法、高级应用场景、实现原理以及使用注意事项,开发者可以更加熟练和安全地使用 WaitGroup 来构建复杂的并发程序。在实际项目中,根据具体的需求合理运用 WaitGroup,可以充分发挥Go语言并发编程的优势,提高程序的性能和效率。同时,要时刻注意避免因不当使用 WaitGroup 而导致的各种问题,确保程序的稳定性和正确性。