前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >字节开源Go协程池 gopool

字节开源Go协程池 gopool

作者头像
王小明_HIT
发布2024-04-17 18:50:25
940
发布2024-04-17 18:50:25
举报
文章被收录于专栏:程序员奇点程序员奇点

字节开源Go协程池gopool

Java 中线程池,也支持自定义线程池,为啥 Golang 官方没有提供协程池的实现?Golang 官方偏向轻量级的并发, 希望通过 go func() 解决问题。

问题

  • 协程数量不可控,在代码并发处理过程中,一不小心 ,go 出了数万个协程, goruntine 虽然轻量级的执行流程,但是不限制的大量创建 goruntine ,对系统性能影响会很大,一个 goruntine 初始栈内存为 2KB,如果新建过多协程,过多 goruntine,内存会达到G级别,如何让协程数可控,是一个问题。
  • 协程泄漏问题,如果协程的bug,导致协程无法被回收,日积月累,可能导致程序崩溃,需要有工具避免协程泄漏问题。

先写一个协程池

一般来说,用 waitGroup 结合 channel ,可以实现一个协程池的功能。一个协程池,一般要具有如下三个功能:

  • 提交任务
  • 启动协程
  • 等待协程执行结束
代码语言:javascript
复制
package main

import (
    "fmt"
    "sync"
    "testing"
)

// 任务结构体
type Task struct {
    ID int
    // 任务
    Job func()
}

// 协程池结构体
type Pool struct {
    // 任务通道
    taskChan chan Task
    // 工作协程数量
    workerCount int
    // 等待组
    wg sync.WaitGroup
}

// 创建协程池
func NewPool(workerCount int) *Pool {
    workChannel := make(chan Task, workerCount)
    return &Pool{
       taskChan:    workChannel,
       workerCount: workerCount,
       wg:          sync.WaitGroup{},
    }
}

// 向协程池提交任务
func (p *Pool) SubmitTask(task Task) {
    p.taskChan <- task
    p.wg.Add(1)
}

// 启动工作协程
func (p *Pool) StartWorkers() {
    for i := 0; i < p.workerCount; i++ {
       go p.worker()
    }
}

// 工作协程
func (p *Pool) worker() {
    for task := range p.taskChan {
       defer p.wg.Done()
       fmt.Printf("Worker received task %d\n", task.ID)
       task.Job()
       fmt.Printf("Worker completed task %d\n", task.ID)
    }
}

func TestThreadPool(t *testing.T) {
    // 创建一个协程池,设置工作协程数量为 5
    pool := NewPool(5)

    // 提交任务到协程池
    for i := 1; i < 5; i++ {
       task := Task{
          ID: i,
          Job: func() {
             fmt.Printf("Task %d is running\n", i)
          },
       }
       pool.SubmitTask(task)
    }

    // 启动工作协程
    pool.StartWorkers()

    // 等待所有任务完成
    pool.wg.Wait()
}

执行结果:

代码语言:javascript
复制
=== RUN   TestThreadPool
Worker received task 1
Task 5 is running
Worker completed task 1
Worker received task 4
Task 5 is running
Worker completed task 4
Worker received task 2
Task 5 is running
Worker completed task 2
Worker received task 3
Task 5 is running
Worker completed task 3
--- PASS: TestThreadPool (0.00s)
PASS

优化一下上面的代码:

  • 将提交任务和协程池启动放一块
  • 引入 ctx, 其中某个协程错误,取消整个协程。
代码语言:javascript
复制
package utils

import (
    "context"
    "sync"
)

// Semaphore 使用waitGroup和channel实现并发同时控制最大并发量
// 参考golang.org/x/sync.errgroup实现返回err功能
type Semaphore struct {
    c       chan struct{}
    wg      sync.WaitGroup
    cancel  func()
    errOnce sync.Once
    err     error
}

func NewSemaphore(maxSize int) *Semaphore {
    return &Semaphore{
       c: make(chan struct{}, maxSize),
    }
}

func NewSemaphoreWithContext(ctx context.Context, maxSize int) (*Semaphore, context.Context) {
    ctx, cancel := context.WithCancel(ctx)
    return &Semaphore{
       c:      make(chan struct{}, maxSize),
       cancel: cancel,
    }, ctx
}

func (s *Semaphore) Go(f func() error) {
    s.wg.Add(1)
    s.c <- struct{}{}
    go func() {
       defer func() {
          if err := recover(); err != nil {
          }
       }()
       defer func() {
          <-s.c
          s.wg.Done()
       }()
       if err := f(); err != nil {
          s.errOnce.Do(func() {
             s.err = err
             if s.cancel != nil {
                s.cancel()
             }
          })
       }
    }()
}

func (s *Semaphore) Wait() error {
    s.wg.Wait()
    if s.cancel != nil {
       s.cancel()
    }
    return s.err
}

测试代码:

代码语言:javascript
复制
package utils

import (
    "math"
    "testing"
    "time"

    "github.com/bmizerany/assert"
)

func sleep1s() error {
    time.Sleep(time.Second)
    return nil
}

func TestSemaphore(t *testing.T) {
    // 最大并发 >= 执行任务数量
    sema := NewSemaphore(4)
    now := time.Now()
    for i := 0; i < 4; i++ {
       sema.Go(sleep1s)
    }
    err := sema.Wait()
    assert.Equal(t, nil, err)
    sec := math.Round(time.Since(now).Seconds())
    assert.Equal(t, 1, int(sec))

    // 设置最大并发为2
    sema = NewSemaphore(2)
    now = time.Now()
    for i := 0; i < 4; i++ {
       sema.Go(sleep1s)
    }
    err = sema.Wait()
    assert.Equal(t, nil, err)
    sec = math.Round(time.Since(now).Seconds())
    assert.Equal(t, 2, int(sec))
}

sync.pool

https://github.com/bytedance/gopkg/tree/develop/util/gopool

原理简介

原理和 Java 线程池原理有点类似

工作流程

  • 工作协程(workerPool):可以设置协程池中的工作协程数(cap)。
代码语言:javascript
复制
// 如果没使用 NewPool方法创建协程池 会默认 init 建一个 default pool
func init() {
    initMetrics()
    defaultPool = NewPool("gopool.DefaultPool", 10000, NewConfig())
}

func NewPool(name string, cap int32, config *Config) Pool {
    p := &pool{
       name:   name,
       cap:    cap,
       config: config,
    }
    return p
}
  • 任务队列(taskPool):用于存放待执行任务的队列,当核心线程都在执行任务时,新的任务会被放入任务队列中等待。
代码语言:javascript
复制
var taskPool sync.Pool

func init() {
    taskPool.New = newTask
}

func newTask() interface{} {
    return &task{}
}

工作流程如下:

  1. 当任务到达时,会将任务加入到工作队列中队尾。
  2. 如果 task 任务数量大于阈值,阈值默认是1且目前的 worker(工作协程)数量小于上限 p.cap 或者没有工作协程,会立即执行任务。
代码语言:javascript
复制
func (p *pool) CtxGo(ctx context.Context, f func()) {
    t := taskPool.Get().(*task)
    t.ctx = ctx
    t.f = f
    p.taskLock.Lock()
    if p.taskHead == nil {
       p.taskHead = t
       p.taskTail = t
    } else {
       p.taskTail.next = t
       p.taskTail = t
    }
    p.taskLock.Unlock()
    atomic.AddInt32(&p.taskCount, 1)
    // 如果 pool 已经被关闭了,就 panic
    if atomic.LoadInt32(&p.closed) == 1 {
       panic("use closed pool")
    }
    // 满足以下两个条件:
    // 1. task 数量大于阈值
    // 2. 目前的 worker 数量小于上限 p.cap(工作协程数)
    // 或者目前没有 worker
    if (atomic.LoadInt32(&p.taskCount) >= p.config.ScaleThreshold && p.WorkerCount() < atomic.LoadInt32(&p.cap)) || p.WorkerCount() == 0 {
       p.incWorkerCount()
       w := workerPool.Get().(*worker)
       w.pool = p
       w.run()
    }
}
  1. 通过 for 循环是从工作队列中取队头任务,然后移动队头指向链表下一节点,执行任务,任务完成后做清理,直至任务队列中没有任务需要执行,协程 return
代码语言:javascript
复制
func (w *worker) run() {
    go func() {
       for {
          //select {
          //case <-w.stopChan:
          // w.close()
          // return
          //default:
          var t *task
          w.pool.taskLock.Lock()
          if w.pool.taskHead != nil {
             t = w.pool.taskHead
             w.pool.taskHead = w.pool.taskHead.next
             atomic.AddInt32(&w.pool.taskCount, -1)
          }
          if t == nil {
             // 如果没有任务要做了,就释放资源,退出
             w.close()
             w.pool.taskLock.Unlock()
             w.Recycle()
             return
          }
          w.pool.taskLock.Unlock()
          func() {
             defer func() {
                if r := recover(); r != nil {
                   logs.CtxFatal(t.ctx, "GOPOOL: panic in pool: %s: %v: %s", w.pool.name, r, debug.Stack())
                   if w.pool.config.EnablePanicMetrics {
                      panicMetricsClient.EmitCounter(panicKey, 1, metrics.T{Name: "pool", Value: w.pool.name})
                   }
                   w.pool.panicHandler(t.ctx, r)
                }
             }()
             t.f()
          }()
          t.Recycle()
          //}
       }
    }()
}

可能会问,为啥要写个死循环去遍历,假设不写 for 循环, 如果一个任务,run 一次,就创建一个工作协程,这个开销成本比较高,通过循环变了任务队列的方式,不断去取,可以避免创建一些不必要的工作协程。

举个例子,假设有 4个任务,任务1 执行,开启了一个工作协程1, 任务2 执行,开启了一个工作协程2,任务3执行,开启了一个工作协程3, 任务4来了,此时工作协程1执行完毕,去取任务4执行。这样的话,4个任务,只需要3个工作协程,如果工作协程执行足够快,工作协程数会更少。

实践

场景:捞取2个月的数据,然后导出 捞取一个月的动账明细数据,然后进行导出,原流程是一个开始时间,一个结束时间,每次捞取10分钟的数据,每次加10分钟,循环处理。改为并发流程后,先将时间按10分钟分段,每一段做为一个任务,交给协程池去跑。最后再对结果进行汇总。项目实测,导出效率提升10倍以上。

参考资料

https://github.com/bytedance/gopkg/tree/develop/util/gopool

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 程序员奇点 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 字节开源Go协程池gopool
    • 问题
      • 先写一个协程池
        • sync.pool
          • 原理简介
          • 工作流程
          • 工作流程如下:
        • 实践
          • 参考资料
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档