MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

使用WaitGroup与Context管理复杂并发任务

2023-08-162.6k 阅读

Go 语言并发编程基础

在深入探讨 WaitGroupContext 之前,我们先来回顾一下 Go 语言并发编程的基础概念。Go 语言在设计之初就将并发编程作为其核心特性之一,通过 goroutinechannel 这两个强大的工具,使得编写高效且简洁的并发程序变得相对容易。

goroutine

goroutine 是 Go 语言中实现并发的轻量级线程。与传统的操作系统线程相比,goroutine 的创建和销毁成本极低,这使得我们可以轻松创建数以万计的 goroutine 来处理并发任务。例如,下面的代码创建了一个简单的 goroutine

package main

import (
    "fmt"
)

func printHello() {
    fmt.Println("Hello, goroutine!")
}

func main() {
    go printHello()
    fmt.Println("Main function")
}

在上述代码中,通过 go 关键字启动了一个 goroutine 来执行 printHello 函数。主函数并不会等待 printHello 函数执行完毕,而是继续向下执行并输出 "Main function"。由于 goroutine 的执行是异步的,在实际运行中,"Hello, goroutine!" 可能会在 "Main function" 之前或之后输出,这取决于调度器的调度策略。

channel

channel 是 Go 语言中用于在 goroutine 之间进行通信和同步的机制。它可以被看作是一个类型安全的管道,数据可以从一端发送,从另一端接收。例如:

package main

import (
    "fmt"
)

func sendData(ch chan int) {
    for i := 0; i < 5; i++ {
        ch <- i
    }
    close(ch)
}

func main() {
    ch := make(chan int)
    go sendData(ch)
    for data := range ch {
        fmt.Println("Received:", data)
    }
}

在这段代码中,sendData 函数通过 ch <- ichannel 发送数据,而主函数通过 for data := range chchannel 接收数据。当 sendData 函数执行完数据发送并关闭 channel 后,for data := range ch 循环会自动结束。

WaitGroup 原理与使用

WaitGroup 是 Go 语言标准库中用于等待一组 goroutine 完成的工具。它提供了一种简单而有效的方式来同步多个 goroutine 的执行,确保在所有相关的 goroutine 完成任务后,主程序才继续执行后续的逻辑。

WaitGroup 结构与方法

WaitGroup 的结构体定义如下(简化版,实际定义在标准库源码中更为复杂):

type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

WaitGroup 主要有三个方法:

  1. Add(delta int):增加等待组的计数器。delta 参数可以是正数或负数,但通常我们传递正数来表示需要等待的 goroutine 数量。
  2. Done():减少等待组的计数器,相当于 Add(-1)。一般在 goroutine 完成任务时调用。
  3. Wait():阻塞调用者,直到等待组的计数器变为零。

简单示例

下面通过一个简单的示例来展示 WaitGroup 的基本用法:

package main

import (
    "fmt"
    "sync"
)

func worker(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Worker %d started\n", id)
    // 模拟一些工作
    for i := 0; i < 1000000000; i++ {
        if i == 999999999 {
            fmt.Printf("Worker %d finished\n", id)
        }
    }
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 3

    for i := 1; i <= numWorkers; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }

    wg.Wait()
    fmt.Println("All workers have finished")
}

在这个示例中,我们创建了 3 个 goroutine 来模拟工作。在启动每个 goroutine 之前,通过 wg.Add(1) 增加等待组的计数器。每个 goroutine 执行 worker 函数,在函数结束时调用 wg.Done() 来减少计数器。主函数通过 wg.Wait() 等待所有 goroutine 完成,只有当所有 goroutine 都调用了 Done() 使得计数器变为零时,主函数才会继续执行并输出 "All workers have finished"。

错误处理与 WaitGroup

在实际应用中,goroutine 可能会发生错误。我们可以通过将错误返回并结合 WaitGroup 来处理这种情况。例如:

package main

import (
    "errors"
    "fmt"
    "sync"
)

var ErrTaskFailed = errors.New("task failed")

func worker(id int, wg *sync.WaitGroup, resultChan chan error) {
    defer wg.Done()
    fmt.Printf("Worker %d started\n", id)
    // 模拟一些工作
    if id == 2 {
        resultChan <- ErrTaskFailed
        return
    }
    fmt.Printf("Worker %d finished\n", id)
    resultChan <- nil
}

func main() {
    var wg sync.WaitGroup
    numWorkers := 3
    resultChan := make(chan error, numWorkers)

    for i := 1; i <= numWorkers; i++ {
        wg.Add(1)
        go worker(i, &wg, resultChan)
    }

    go func() {
        wg.Wait()
        close(resultChan)
    }()

    for err := range resultChan {
        if err != nil {
            fmt.Println("Error:", err)
            return
        }
    }
    fmt.Println("All workers have finished successfully")
}

在这个改进的示例中,worker 函数可以返回错误到 resultChan。主函数通过 WaitGroup 等待所有 goroutine 完成后关闭 resultChan,然后从 resultChan 中接收错误。如果有任何一个 goroutine 返回错误,主函数就会输出错误信息并结束;否则输出 "All workers have finished successfully"。

WaitGroup 的局限性

虽然 WaitGroup 在很多场景下非常实用,但它也有一些局限性。例如,WaitGroup 一旦创建,其计数器的值只能增加或减少,无法动态调整等待的 goroutine 数量。而且 WaitGroup 本身没有提供取消机制,如果需要提前终止一组 goroutine 的执行,WaitGroup 无法直接满足这个需求。这时候就需要引入 Context 来解决这些问题。

Context 原理与使用

Context 是 Go 1.7 引入的一个重要特性,用于在 goroutine 树中传递截止时间、取消信号和其他请求范围的值。它为管理复杂并发任务提供了一种优雅且强大的方式。

Context 接口与实现

Context 是一个接口,定义如下:

type Context interface {
    Deadline() (deadline time.Time, ok bool)
    Done() <-chan struct{}
    Err() error
    Value(key interface{}) interface{}
}
  • Deadline():返回截止时间。oktrue 时表示设置了截止时间,deadline 为截止时间点。
  • Done():返回一个只读的 channel,当 Context 被取消或超时时,这个 channel 会被关闭。
  • Err():返回 Context 被取消或超时的原因。如果 Done() 还未关闭,Err() 返回 nil
  • Value(key interface{}):返回与 key 关联的值,如果没有则返回 nil

Go 标准库提供了几个创建 Context 的函数,最常用的是 context.Background()context.TODO(),它们返回一个空的 Context,通常作为 Context 树的根节点。另外,context.WithCancel()context.WithDeadline()context.WithTimeout() 用于创建可取消或有截止时间的 Context

使用 Context 取消 goroutine

下面是一个使用 context.WithCancel() 取消 goroutine 的示例:

package main

import (
    "context"
    "fmt"
    "time"
)

func worker(ctx context.Context) {
    for {
        select {
        case <-ctx.Done():
            fmt.Println("Worker stopped")
            return
        default:
            fmt.Println("Worker is working...")
            time.Sleep(1 * time.Second)
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    go worker(ctx)

    time.Sleep(3 * time.Second)
    cancel()
    time.Sleep(1 * time.Second)
    fmt.Println("Main function finished")
}

在这个示例中,context.WithCancel(context.Background()) 创建了一个可取消的 Context 和对应的取消函数 cancelworker 函数通过 select 语句监听 ctx.Done() 信号。主函数在运行 3 秒后调用 cancel() 函数取消 Contextworker 函数监听到 ctx.Done() 信号后停止工作并退出。

使用 Context 设置截止时间

context.WithDeadline()context.WithTimeout() 用于设置 Context 的截止时间。context.WithTimeout() 实际上是基于 context.WithDeadline() 实现的,它接受一个超时时间参数。例如:

package main

import (
    "context"
    "fmt"
    "time"
)

func worker(ctx context.Context) {
    select {
    case <-ctx.Done():
        fmt.Println("Worker stopped:", ctx.Err())
        return
    default:
        fmt.Println("Worker is working...")
        time.Sleep(5 * time.Second)
        fmt.Println("Worker finished")
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()
    go worker(ctx)

    time.Sleep(4 * time.Second)
    fmt.Println("Main function finished")
}

在这个例子中,context.WithTimeout(context.Background(), 3*time.Second) 设置了一个 3 秒的超时时间。worker 函数在执行过程中,如果超过 3 秒,ctx.Done() 信号会被触发,worker 函数会停止并输出停止原因。

Context 传递值

Context 还可以在 goroutine 树中传递值。例如:

package main

import (
    "context"
    "fmt"
)

type UserKey struct{}

func worker(ctx context.Context) {
    user, ok := ctx.Value(UserKey{}).(string)
    if ok {
        fmt.Printf("Worker: User is %s\n", user)
    } else {
        fmt.Println("Worker: No user found")
    }
}

func main() {
    ctx := context.WithValue(context.Background(), UserKey{}, "John")
    go worker(ctx)

    time.Sleep(1 * time.Second)
    fmt.Println("Main function finished")
}

在这个示例中,通过 context.WithValue(context.Background(), UserKey{}, "John") 将用户信息 "John" 附加到 Context 中。worker 函数通过 ctx.Value(UserKey{}) 获取这个值并进行处理。

使用 WaitGroup 与 Context 协同管理复杂并发任务

在实际的复杂并发场景中,我们通常需要同时使用 WaitGroupContext 来实现任务的同步和取消。

场景一:多个任务并行且可取消

假设我们有多个任务需要并行执行,并且可以随时取消。例如,我们要从多个数据源获取数据,在获取到足够的数据或者用户取消操作时停止。

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

func fetchData(ctx context.Context, id int, wg *sync.WaitGroup, resultChan chan int) {
    defer wg.Done()
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("Fetch %d stopped\n", id)
            return
        default:
            fmt.Printf("Fetch %d working...\n", id)
            time.Sleep(1 * time.Second)
            resultChan <- id * 10
        }
    }
}

func main() {
    var wg sync.WaitGroup
    ctx, cancel := context.WithCancel(context.Background())
    resultChan := make(chan int, 10)
    numFetches := 3

    for i := 1; i <= numFetches; i++ {
        wg.Add(1)
        go fetchData(ctx, i, &wg, resultChan)
    }

    go func() {
        wg.Wait()
        close(resultChan)
    }()

    var totalData int
    for data := range resultChan {
        totalData += data
        fmt.Printf("Received data: %d, total: %d\n", data, totalData)
        if totalData >= 50 {
            cancel()
            break
        }
    }

    time.Sleep(1 * time.Second)
    fmt.Println("Main function finished")
}

在这个示例中,fetchData 函数模拟从数据源获取数据的任务。每个任务通过 WaitGroup 进行同步,并且可以通过 Context 取消。主函数通过监听 resultChan 来收集数据,当收集到的数据总和达到 50 时,调用 cancel() 取消所有任务。

场景二:任务链与超时控制

有时候我们的任务是一个任务链,前一个任务的结果作为后一个任务的输入,并且整个任务链有超时限制。例如:

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

func task1(ctx context.Context, wg *sync.WaitGroup, resultChan chan int) {
    defer wg.Done()
    select {
    case <-ctx.Done():
        fmt.Println("Task1 stopped")
        return
    default:
        fmt.Println("Task1 is working...")
        time.Sleep(2 * time.Second)
        resultChan <- 10
    }
}

func task2(ctx context.Context, wg *sync.WaitGroup, inputChan <-chan int, resultChan chan int) {
    defer wg.Done()
    select {
    case <-ctx.Done():
        fmt.Println("Task2 stopped")
        return
    case input := <-inputChan:
        fmt.Printf("Task2 received input: %d\n", input)
        fmt.Println("Task2 is working...")
        time.Sleep(2 * time.Second)
        resultChan <- input * 2
    }
}

func main() {
    var wg sync.WaitGroup
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()

    task1ResultChan := make(chan int, 1)
    task2ResultChan := make(chan int, 1)

    wg.Add(1)
    go task1(ctx, &wg, task1ResultChan)

    wg.Add(1)
    go task2(ctx, &wg, task1ResultChan, task2ResultChan)

    go func() {
        wg.Wait()
        close(task1ResultChan)
        close(task2ResultChan)
    }()

    for result := range task2ResultChan {
        fmt.Printf("Final result: %d\n", result)
    }

    time.Sleep(1 * time.Second)
    fmt.Println("Main function finished")
}

在这个场景中,task1 先执行,其结果通过 task1ResultChan 传递给 task2。整个任务链设置了 3 秒的超时时间。如果 task1task2 在超时时间内未完成,Context 会取消任务,相关的 goroutine 会收到取消信号并停止执行。

常见问题与注意事项

Context 的传递

在复杂的 goroutine 调用链中,确保 Context 正确传递至关重要。每个 goroutine 都应该使用父 goroutine 传递下来的 Context 或者基于父 Context 创建新的 Context。如果 Context 传递错误,可能会导致 goroutine 无法正确接收取消信号或截止时间信息。

WaitGroup 的计数器操作

WaitGroup 计数器的操作要特别小心。确保在启动 goroutine 之前正确调用 Add 方法,并且在 goroutine 结束时调用 Done 方法。如果计数器操作不当,例如忘记调用 Add 或者多次调用 Done,可能会导致 Wait 方法永远阻塞或者提前返回。

资源释放

goroutine 因为 Context 取消或其他原因提前结束时,要确保相关的资源(如文件句柄、网络连接等)被正确释放。可以使用 defer 语句来确保资源在 goroutine 结束时被释放。

性能考虑

虽然 goroutine 是轻量级的,但过多的 goroutine 同时运行仍然可能导致性能问题。在使用 WaitGroupContext 管理并发任务时,要根据实际需求合理控制 goroutine 的数量,避免资源过度消耗。

通过合理使用 WaitGroupContext,我们可以有效地管理 Go 语言中的复杂并发任务,实现高效、健壮且可控制的并发程序。无论是在网络编程、分布式系统还是其他需要并发处理的场景中,这两个工具都能发挥重要作用。