MK
摩柯社区 - 一个极简的技术知识社区
AI 面试

Go Context的使用场景梳理

2021-09-016.2k 阅读

1. 控制并发操作的生命周期

在 Go 语言中,并发编程是其核心优势之一。然而,当我们启动多个 goroutine 进行并发操作时,如何有效地管理它们的生命周期是一个重要问题。Context 提供了一种优雅的方式来控制并发操作的取消和超时,从而避免资源浪费和潜在的内存泄漏。

1.1 取消操作

假设我们有一个复杂的任务,由多个 goroutine 协作完成,当其中一个子任务出现错误或者外部条件满足时,我们希望能够取消整个任务。以下是一个简单的示例:

package main

import (
    "context"
    "fmt"
    "time"
)

func worker(ctx context.Context, id int) {
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("Worker %d stopped\n", id)
            return
        default:
            fmt.Printf("Worker %d working\n", id)
            time.Sleep(100 * time.Millisecond)
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())

    for i := 1; i <= 3; i++ {
        go worker(ctx, i)
    }

    time.Sleep(500 * time.Millisecond)
    cancel()

    time.Sleep(200 * time.Millisecond)
}

在上述代码中,我们通过 context.WithCancel 创建了一个可取消的 Context。每个 worker goroutine 在 select 语句中监听 ctx.Done() 信号。当 cancel 函数被调用时,所有的 worker goroutine 都会收到取消信号并退出。

1.2 超时控制

有时,我们希望某个并发操作在一定时间内完成,否则自动取消。例如,发起一个 HTTP 请求并设置超时时间。

package main

import (
    "context"
    "fmt"
    "net/http"
    "time"
)

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    req, err := http.NewRequestWithContext(ctx, "GET", "https://example.com", nil)
    if err != nil {
        fmt.Println("Error creating request:", err)
        return
    }

    client := &http.Client{}
    resp, err := client.Do(req)
    if err != nil {
        if ctx.Err() == context.DeadlineExceeded {
            fmt.Println("Request timed out")
        } else {
            fmt.Println("Request error:", err)
        }
        return
    }
    defer resp.Body.Close()

    fmt.Println("Request successful")
}

在这个例子中,我们使用 context.WithTimeout 创建了一个带有超时时间的 Context,并将其传递给 http.NewRequestWithContext。如果请求在 2 秒内没有完成,ctx.Err() 将返回 context.DeadlineExceeded,我们可以据此进行相应的处理。

2. 传递请求范围的数据

在 Web 应用开发中,一个 HTTP 请求可能会涉及多个函数调用和多个 goroutine 的协作。Context 提供了一种在这些函数和 goroutine 之间传递请求范围数据的方式。

2.1 传递用户认证信息

假设我们有一个基于 JWT(JSON Web Token)的认证系统,在处理 HTTP 请求时,我们需要在不同的中间件和处理函数之间传递用户的认证信息。

package main

import (
    "context"
    "fmt"
    "net/http"
)

type userInfo struct {
    UserID string
    Role   string
}

func authMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // 假设这里从 JWT 中解析出用户信息
        user := userInfo{
            UserID: "12345",
            Role:   "admin",
        }
        ctx := context.WithValue(r.Context(), "userInfo", user)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func userHandler(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    user, ok := ctx.Value("userInfo").(userInfo)
    if!ok {
        http.Error(w, "Unauthorized", http.StatusUnauthorized)
        return
    }
    fmt.Fprintf(w, "User ID: %s, Role: %s\n", user.UserID, user.Role)
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/user", authMiddleware(http.HandlerFunc(userHandler)))

    http.ListenAndServe(":8080", mux)
}

在上述代码中,authMiddleware 从 JWT 中解析出用户信息,并通过 context.WithValue 将用户信息附加到请求的 Context 中。userHandler 从 Context 中获取用户信息并进行相应的处理。

2.2 传递请求 ID

在分布式系统中,为了方便跟踪和调试,通常会为每个请求生成一个唯一的请求 ID。这个请求 ID 需要在整个请求处理过程中传递。

package main

import (
    "context"
    "fmt"
    "net/http"
    "strconv"
    "time"
)

func requestIDMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        requestID := strconv.FormatInt(time.Now().UnixNano(), 10)
        ctx := context.WithValue(r.Context(), "requestID", requestID)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func logHandler(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx := r.Context()
        requestID := ctx.Value("requestID")
        fmt.Printf("Request ID: %v\n", requestID)
        next.ServeHTTP(w, r)
    })
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/", logHandler(requestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        fmt.Fprintf(w, "Hello, World!")
    }))))

    http.ListenAndServe(":8080", mux)
}

在这个例子中,requestIDMiddleware 为每个请求生成一个唯一的请求 ID,并将其附加到 Context 中。logHandler 从 Context 中获取请求 ID 并进行日志记录。

3. 资源清理

在并发操作中,当一个操作被取消或者超时后,可能需要清理相关的资源。Context 可以与 defer 语句结合使用,确保资源得到正确的清理。

3.1 文件资源清理

假设我们在一个 goroutine 中打开一个文件进行读写操作,当操作被取消时,我们需要关闭文件以释放资源。

package main

import (
    "context"
    "fmt"
    "os"
    "time"
)

func fileOperation(ctx context.Context) {
    file, err := os.Open("test.txt")
    if err != nil {
        fmt.Println("Error opening file:", err)
        return
    }
    defer file.Close()

    select {
    case <-ctx.Done():
        fmt.Println("File operation cancelled")
        return
    default:
        // 模拟文件操作
        fmt.Println("Reading file...")
        time.Sleep(2 * time.Second)
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()

    go fileOperation(ctx)

    time.Sleep(2 * time.Second)
}

在上述代码中,fileOperation 函数打开一个文件,并在 select 语句中监听 ctx.Done() 信号。当操作被取消时,defer 语句会确保文件被关闭。

3.2 数据库连接清理

在数据库操作中,当一个数据库事务因为某种原因被取消时,我们需要关闭数据库连接以避免资源泄漏。

package main

import (
    "context"
    "fmt"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
    "time"
)

func databaseOperation(ctx context.Context) {
    dsn := "user:password@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
    db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
    if err != nil {
        fmt.Println("Error connecting to database:", err)
        return
    }
    defer func() {
        sqlDB, err := db.DB()
        if err != nil {
            fmt.Println("Error getting SQL DB:", err)
            return
        }
        sqlDB.Close()
    }()

    select {
    case <-ctx.Done():
        fmt.Println("Database operation cancelled")
        return
    default:
        // 模拟数据库操作
        fmt.Println("Querying database...")
        time.Sleep(2 * time.Second)
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()

    go databaseOperation(ctx)

    time.Sleep(2 * time.Second)
}

在这个例子中,databaseOperation 函数连接到数据库,并在 defer 语句中关闭数据库连接。当操作被取消时,数据库连接会被正确关闭。

4. 控制子 goroutine 的并发度

在某些情况下,我们需要控制同时运行的子 goroutine 的数量,以避免资源耗尽。Context 可以与 sync.WaitGroupchannel 结合使用来实现这一目的。

4.1 限制并发度示例

假设我们有一个任务队列,需要从队列中取出任务并并发处理,但同时要限制并发处理的任务数量。

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

func worker(ctx context.Context, id int, taskChan <-chan int, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("Worker %d stopped\n", id)
            return
        case task, ok := <-taskChan:
            if!ok {
                return
            }
            fmt.Printf("Worker %d processing task %d\n", id, task)
            time.Sleep(100 * time.Millisecond)
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    var wg sync.WaitGroup
    const maxWorkers = 3
    taskChan := make(chan int, 10)

    for i := 1; i <= maxWorkers; i++ {
        wg.Add(1)
        go worker(ctx, i, taskChan, &wg)
    }

    for i := 1; i <= 10; i++ {
        taskChan <- i
    }
    close(taskChan)

    time.Sleep(500 * time.Millisecond)
    cancel()

    wg.Wait()
}

在上述代码中,我们通过 context.WithCancel 创建了一个可取消的 Context,并启动了 maxWorkersworker goroutine。每个 workertaskChan 中获取任务并处理。当 cancel 函数被调用时,所有的 worker goroutine 都会收到取消信号并退出。

5. 处理分布式系统中的请求

在分布式系统中,一个请求可能会跨越多个服务和节点。Context 可以在这些服务之间传递,以实现统一的请求控制和跟踪。

5.1 分布式追踪

假设我们有一个微服务架构,其中一个请求会经过多个服务。我们可以使用 Context 来传递追踪 ID,以便在整个系统中跟踪请求的路径。

// 服务 A
package main

import (
    "context"
    "fmt"
    "net/http"
    "strconv"
    "time"
)

func serviceA(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        traceID := strconv.FormatInt(time.Now().UnixNano(), 10)
        ctx := context.WithValue(r.Context(), "traceID", traceID)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func forwardToServiceB(ctx context.Context) {
    // 模拟向服务 B 转发请求,并传递 Context
    traceID := ctx.Value("traceID")
    fmt.Printf("Service A forwarding request with trace ID: %v\n", traceID)
    // 这里实际应该是 HTTP 调用等操作
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/", serviceA(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        forwardToServiceB(r.Context())
        fmt.Fprintf(w, "Response from Service A")
    })))

    http.ListenAndServe(":8080", mux)
}

// 服务 B
package main

import (
    "context"
    "fmt"
    "net/http"
)

func serviceB(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    traceID := ctx.Value("traceID")
    fmt.Printf("Service B received request with trace ID: %v\n", traceID)
    fmt.Fprintf(w, "Response from Service B")
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/", http.HandlerFunc(serviceB))

    http.ListenAndServe(":8081", mux)
}

在这个例子中,服务 A 生成一个追踪 ID 并附加到 Context 中,然后将请求转发到服务 B。服务 B 从 Context 中获取追踪 ID 并进行相应的日志记录。

5.2 分布式请求取消

在分布式系统中,当一个请求的某个部分出现错误或者超时,我们希望能够取消整个请求涉及的所有操作。可以通过传递 Context 来实现这一点。

// 服务 A
package main

import (
    "context"
    "fmt"
    "net/http"
    "time"
)

func serviceA(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
        defer cancel()
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

func forwardToServiceB(ctx context.Context) {
    // 模拟向服务 B 转发请求,并传递 Context
    fmt.Printf("Service A forwarding request\n")
    // 这里实际应该是 HTTP 调用等操作
    time.Sleep(3 * time.Second)
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/", serviceA(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        go forwardToServiceB(r.Context())
        fmt.Fprintf(w, "Response from Service A")
    })))

    http.ListenAndServe(":8080", mux)
}

// 服务 B
package main

import (
    "context"
    "fmt"
    "net/http"
)

func serviceB(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    select {
    case <-ctx.Done():
        fmt.Println("Service B request cancelled")
        http.Error(w, "Request cancelled", http.StatusRequestTimeout)
        return
    default:
        fmt.Println("Service B processing request")
        fmt.Fprintf(w, "Response from Service B")
    }
}

func main() {
    mux := http.NewServeMux()
    mux.Handle("/", http.HandlerFunc(serviceB))

    http.ListenAndServe(":8081", mux)
}

在这个例子中,服务 A 为请求设置了一个超时时间,并将带有超时的 Context 传递给服务 B。如果服务 A 的操作超时,服务 B 会收到取消信号并进行相应的处理。

6. 在测试中使用 Context

在编写测试时,Context 可以帮助我们模拟不同的场景,例如取消操作和超时,从而更好地验证代码的正确性。

6.1 测试取消操作

假设我们有一个函数,它在收到取消信号时会停止工作。我们可以使用 Context 来测试这个函数的取消逻辑。

package main

import (
    "context"
    "fmt"
    "testing"
    "time"
)

func worker(ctx context.Context, id int) {
    for {
        select {
        case <-ctx.Done():
            fmt.Printf("Worker %d stopped\n", id)
            return
        default:
            fmt.Printf("Worker %d working\n", id)
            time.Sleep(100 * time.Millisecond)
        }
    }
}

func TestWorkerCancel(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())

    go worker(ctx, 1)

    time.Sleep(300 * time.Millisecond)
    cancel()

    time.Sleep(200 * time.Millisecond)
}

在上述测试代码中,我们创建了一个可取消的 Context,并启动一个 worker goroutine。然后,我们在一段时间后调用 cancel 函数,并等待一段时间以确保 worker goroutine 正确收到取消信号并退出。

6.2 测试超时操作

对于带有超时逻辑的函数,我们可以使用 Context 来测试其超时处理是否正确。

package main

import (
    "context"
    "fmt"
    "testing"
    "time"
)

func longRunningOperation(ctx context.Context) {
    select {
    case <-ctx.Done():
        fmt.Println("Operation cancelled")
        return
    case <-time.After(500 * time.Millisecond):
        fmt.Println("Operation completed")
    }
}

func TestLongRunningOperationTimeout(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
    defer cancel()

    go longRunningOperation(ctx)

    time.Sleep(500 * time.Millisecond)
}

在这个测试中,我们创建了一个带有超时时间的 Context,并启动一个 longRunningOperation goroutine。由于超时时间设置为 300 毫秒,而操作本身需要 500 毫秒,因此 ctx.Done() 信号会先被触发,从而验证了超时逻辑的正确性。

通过以上对 Go Context 使用场景的梳理,我们可以看到 Context 在 Go 语言的并发编程、Web 开发、分布式系统等多个领域都发挥着重要作用。合理使用 Context 可以使我们的代码更加健壮、易于维护和扩展。