实现异步并发 worker 队列

Posted April 11, 2022 by Yusank ‐ 8 min read

在开发 broadcast 功能的时候,碰到一个比较棘手的问题,需要并发执行多个 worker 来讲 broadcast 消息推送到所有在线用户,同时我希望能控制并发数量。本文记录实现一个异步并发 worker 队列的过程。

前言

以往遇到类似的问题我都会借助 sync.WaitGroupchannel 的方式去做,实现方式也比较简单。大致思路如下:

 1type LimitedWaitGroup struct {
 2    wg *sync.WaitGroup
 3    ch chan int
 4}
 5
 6func NewLimitedWaitGroup(size int) *LimitedWaitGroup {
 7    return &LimitedWaitGroup{
 8        wg : new(sync.WaitGroup),
 9        ch : make(chan int, size)
10    }
11}
12
13func (w *LimitedWaitGroup) Add(f func()) {
14    // wait if channel is full
15    w.ch <- 1
16    w.wg.Add(1)
17    go func() {
18        defer w.done()
19        f()
20    }()
21}
22
23func (w *LimitedWaitGroup) done() {
24    <-w.ch
25    w.wg.Done()
26}
27
28func (w *LimitedWaitGroup) Wait() {
29    w.wg.Wait()
30}

这样能解决我大部分的简单需求,但是现在我想要的能力用这个简单的 LimitedWaitGroup 无法完全满足,所以重新设计了一个 worker pool 的概念来满足我现在以及以后类似的需求。

设计

需求整理

首先将目前想到的需求以及其优先级列出来:

高优先级:

  1. worker pool 支持设置 size,防止 worker 无限增多
  2. 任务并发执行且能指定并发数
  3. 当 worker 达到上限时,新的任务在一定范围内支持排队等待(即 limited queue
  4. 支持捕获任务错误
  5. 排队中的任务应该按顺序调度执行

低优先级:

  1. 任务支持实时状态更新
  2. 任务可以外部等待完成(类似 waitGroup.Done()
  3. 当空闲 worker 小于指定并发数时,支持占用空闲 worker 部分运行(如当前剩余 3 个 worker 可用,但是新的任务需要 5 个并发,则尝试先占用这 3 个worker,并在运行过程中继续监听 pool 空闲出来的 worker 并尝试去占用)

小结
列出完需求及其优先级后,经过考虑决定,高优先级除了第五条, 低优先级除了第三条, 其他需求都在目前版本里实现。

原因如下:

  • 首先说低优先级第三条,这块的部分调度执行 worker,目前没有想好比较优雅的实现方式,所以暂时没有实现(但是下个版本会实现)
  • 高优先级的第五条也是跟调度有点关系,如果队列里靠前的任务需要大量的 worker,那很容易造成阻塞,后面的 task 一直没办法执行,即便需要很少的 worker。所以等部分调度执行开发完再把任务按需执行打开。

Task Definition

task 表示一次任务,包含了任务执行的方法,并发数,所属的 workerSet以及执行状态等。

 1type TaskFunc func() error
 2
 3type task struct {
 4    tf          TaskFunc        // task function
 5    concurrence int             // concurrence of task
 6    ws          *workerSet      // assign value after task distribute to worker.
 7    status      TaskStatus      // store task status.
 8}
 9
10// TaskStatus is the status of task.
11type TaskStatus int

Worker Definition

worker 作为最小调度单元,仅包含 workerSeterror .

1type worker struct {
2    ws  *workerSet
3    err error
4}

TaskResult Definition

TaskResult 是一个对外暴露的 interface, 用于外部调用者获取和管理任务执行状态信息。

 1// TaskResult is a manager of submitted task.
 2type TaskResult interface {
 3    // get error if task failed.
 4    Err() error
 5    // wait for task done.
 6    Wait()
 7    // get task status.
 8    Status() TaskStatus
 9    // kill task.
10    Kill()
11}

taskTaskStatus 分别实现 TaskResult 的接口,从而外部统一拿到 TaskResult

之所以 TaskStatus 也需要实现 TaskResult 是因为部分情况下,不需要创建 task 直接返回错误状态即可。如: 提交的任务的并发数过高(超过 pool 的 size),当前 queue 已满不能再处理任何其他任务了,这种情况直接返回对应的状态码。

WorkerSet Definition

workerSet 为一组 worker的集合,作用是调度 worker 并维护起所属 task 的整个生命过程.

1// workerSet represent a group of task handle workers
2type workerSet struct {
3    task          *task
4    runningWorker atomic.Int32
5    workers       []*worker
6    ctx           context.Context
7    cancel        context.CancelFunc
8    wg            *sync.WaitGroup
9}

WorkerPool Definition

Pool 是一个可指定 size 的 worker pool. 可并发运行多个 task 并且支持额外的任务排队能力。

 1// Pool is a buffered worker pool
 2type Pool struct {
 3    // TODO: taskQueue should be a linked list, so that we can get the task from the head of the list and put it back to the head.
 4    // If we use a channel as taskQueue, we can't get the task from the head of the list and put it back to the head.
 5    // But make sure that before change it to linked list, we should have the ability run the task in min(taskQueue length, concurrence) goroutines.
 6    taskQueue         chan *task
 7    enqueuedTaskCount atomic.Int32 // count of unhandled tasks
 8    bufferSize        int          // size of taskQueue buffer, means can count of bufferSize task can wait to be handled
 9    maxWorker         int          // count of how many worker run in concurrence
10    workerSets        []*workerSet
11    lock              *sync.Mutex
12    stopFlag          atomic.Bool
13}

实现

上面已经确定需要的能力和基础的数据结构了,下面一个个去实现各个模块的能力。

Worker Implement

worker 能力相对纯粹,看看 worker 是如何工作的:

 1func (w *worker) run() {
 2    defer w.ws.done()
 3
 4    var ec = make(chan error, 1)
 5    defer close(ec)
 6    go func() {
 7        ec <- w.ws.task.tf()
 8    }()
 9
10    select {
11    case e := <-ec:
12        w.err = e
13    case <-w.ws.ctx.Done():
14    }
15}

WorkerSet Implement

workerSet 调度 worker,记录 worker 运行状态等。

点击展开
 1
 2func newWorkerSet(ctx context.Context, t *task) *workerSet {
 3    // 初始化参数
 4    // ... 省略代码
 5    return ws
 6}
 7
 8func (ws *workerSet) run() {
 9    ws.task.updateStatus(TaskStatusRunning)
10    for _, w := range ws.workers {
11        ws.addOne()
12        go w.run()
13    }
14}
15
16func (ws *workerSet) stopAll() {
17    ws.cancel()
18    ws.task.updateStatus(TaskStatusKilled)
19}
20
21// err returns the first error that occurred in the workerSet.
22func (ws *workerSet) err() error {
23    // ...省略代码
24    return nil
25}
26
27func (ws *workerSet) getRunningWorker() int {
28    return int(ws.runningWorker.Load())
29}
30
31// done called when worker stop.
32func (ws *workerSet) done() {
33    ws.addRunningWorker(-1)
34    ws.wg.Done()
35    if ws.getRunningWorker() == 0 {
36        ws.task.updateStatus(TaskStatusDone)
37    }
38}
39
40// addOne called when worker start running.
41func (ws *workerSet) addOne() {
42    ws.addRunningWorker(1)
43    ws.wg.Add(1)
44}
45
46func (ws *workerSet) wait() {
47    ws.wg.Wait()
48}

Task Implement

task 主要是记录 task 的状态,并通过 workerSet 控制其下的 worker.

点击展开
 1func newTask(tf TaskFunc, concurrence int) *task {
 2    return &task{
 3        tf:          tf,
 4        concurrence: concurrence,
 5    }
 6}
 7
 8// Err returns the first error that occurred in the workerSet.
 9func (t *task) Err() error {
10    // check t.ws if nil return nil.
11    if t.ws == nil {
12        return nil
13    }
14
15    return t.ws.err()
16}
17
18// Wait for task done.
19// Please make sure task is done or running before call this function.
20func (t *task) Wait() {
21    // check t.ws if nil.
22    if t.ws == nil {
23        return
24    }
25
26    t.ws.wait()
27}
28
29// Status returns task status.
30func (t *task) Status() TaskStatus {
31    return t.status
32}
33
34// Kill task.
35func (t *task) Kill() {
36    // check t.ws if nil.
37    if t.ws == nil {
38        return
39    }
40
41    t.ws.stopAll()
42}
43
44func (t *task) assignWorkerSet(ws *workerSet) {
45    t.ws = ws
46}
47
48func (t *task) updateStatus(status TaskStatus) {
49    t.status = status
50}

TaskStatus Implement

TaskStatus 虽然实现了 TaskResult 接口,但是不能控制任何 task,其有效的方法只有 Status()Err()

 1func (t TaskStatus) Error() string {
 2    return t.String()
 3}
 4
 5func (t TaskStatus) Err() error {
 6    switch t {
 7    case TaskStatusError, TaskStatusQueueFull, TaskStatusTooManyWorker, TaskStatusPoolClosed, TaskStatusKilled:
 8        return t
 9    }
10
11    return nil
12}
13
14func (t TaskStatus) Wait() {
15    // do nothing.
16}
17
18func (t TaskStatus) Status() TaskStatus {
19    return t
20}
21
22func (t TaskStatus) Kill() {
23    // do nothing.
24}

Pool Implement

Pool 是总的入口,任务会提交到 Pool, 并由 Pool 创建 task 并调度到 workerSet 上,同时定时清理已完成的 workerSet, 确保空闲 worker 能被合理使用。

点击展开
 1func NewPool(workerSize, queueSize int) *Pool {
 2    // ... 初始化各个参数
 3
 4    // check p.enqueue to find out why make this channel size with p.bufferSize+1.
 5    //
 6    p.taskQueue = make(chan *task, p.bufferSize+1)
 7    // 启动单独 goroutine 维护队列
 8    go p.consumeQueue()
 9    return p
10}
11
12func (p *Pool) Submit(ctx context.Context, tf TaskFunc, concurrence int) TaskResult {
13    if p.stopFlag.Load() {
14        return TaskStatusPoolClosed
15    }
16
17    if concurrence > p.maxWorker {
18        return TaskStatusTooManyWorker
19    }
20
21    // check if there has any worker place left
22    p.lock.Lock()
23    defer p.lock.Unlock()
24
25    t := newTask(tf, concurrence)
26    if p.tryRunTask(ctx, t) {
27        return t
28    }
29
30    if p.enqueueTask(t, true) {
31        return t
32    }
33
34    return TaskStatusQueueFull
35}
36
37func (p *Pool) Stop() {
38    // 关闭队列和正在运行的 workerSet
39}
40
41// tryRunTask try to put task into workerSet and run it.Return false if capacity not enough.
42// Make sure get p.Lock before call this func
43func (p *Pool) tryRunTask(ctx context.Context, t *task) bool {
44    if p.curRunningWorkerNum()+t.concurrence <= p.maxWorker {
45        ws := newWorkerSet(ctx, t)
46        p.workerSets = append(p.workerSets, ws)
47        // run 为异步方法
48        ws.run()
49        return true
50    }
51
52    return false
53}
54
55// curRunningWorkerNum get current running worker num
56// make sure lock mutex before call this func
57func (p *Pool) curRunningWorkerNum() int {
58    // ...省略代码
59    return cnt
60}
61
62// enqueueTask put task to queue.
63// p.enqueuedTaskCount increase 1 if is new task
64func (p *Pool) enqueueTask(t *task, isNewTask bool) bool {
65    // ... 省略代码
66    return true
67}
68
69func (p *Pool) consumeQueue() {
70    var ticker = time.NewTicker(time.Second)
71    for {
72        select {
73        case t, ok := <-p.taskQueue:
74            if !ok {
75                // channel closed
76                return
77            }
78            if p.tryRunTask(context.Background(), t) {
79                p.enqueuedTaskCount.Sub(1)
80                goto unlock
81            }
82            // if enqueueTask return false, means channel is closed.
83            if !p.enqueueTask(t, false) {
84                // channel is closed
85                goto unlock
86            }
87
88        unlock:
89            p.lock.Unlock()
90        case <-ticker.C:
91            log.Printf("current running worker num: %d", p.curRunningWorkerNum())
92        }
93    }
94
95    // never reach here
96}

使用

到这里相关开发基本结束了,有一些 TODO 项后面后补充完善,下面通过 test case 来看一下如何使用这个 worker pool:

 1func TestPool_SubmitOrEnqueue(t *testing.T) {
 2    p := NewPool(5, 1)
 3    var (
 4        cnt         int
 5        concurrence = 5
 6    )
 7
 8    tf := func() error {
 9        time.Sleep(time.Second)
10        log.Println("hello world")
11        cnt++
12        return nil
13    }
14
15    got := p.Submit(context.Background(), tf, concurrence)
16    if got.Status() != TaskStatusRunning {
17        t.Errorf("SubmitOrEnqueue() = %v, want %v", got.Status(), TaskStatusRunning)
18        return
19    }
20    got.Wait()
21    if cnt != concurrence {
22        t.Errorf("cnt = %v, want %v", cnt, concurrence)
23    }
24    if got := p.Submit(context.Background(), tf, concurrence); got.Status() != TaskStatusRunning {
25        t.Errorf("SubmitOrEnqueue() = %v, want %v", got, TaskStatusRunning)
26        return
27    }
28    if got := p.Submit(context.Background(), tf, concurrence); got.Status() != TaskStatusEnqueue {
29        t.Errorf("SubmitOrEnqueue() = %v, want %v", got.Status(), TaskStatusEnqueue)
30        return
31    }
32    if got := p.Submit(context.Background(), tf, 6); got.Status() != TaskStatusTooManyWorker {
33        t.Errorf("SubmitOrEnqueue() = %v, want %v", got.Status(), TaskStatusTooManyWorker)
34        return
35    }
36    if got := p.Submit(context.Background(), tf, concurrence); got.Status() != TaskStatusQueueFull {
37        t.Errorf("SubmitOrEnqueue() = %v, want %v", got.Status(), TaskStatusQueueFull)
38        return
39    }
40    p.Stop()
41    if got := p.Submit(context.Background(), tf, 1); got.Status() != TaskStatusPoolClosed {
42        t.Errorf("SubmitOrEnqueue() = %v, want %v", got.Status(), TaskStatusPoolClosed)
43        return
44    }
45}

总结

到这里这篇文章内容全部结束了,下面做一个简单的总结:

  • 介绍背景和需求
  • 根据需求定义了一组概念:task, worker, workerSet, pool
  • 各个结之间的关系以及如何实现
  • 最终给出使用的 test case.

链接 🔗

如果想仔细阅读源码,并持续关注这块功能的后续更新优化,请点击这里跳转到 GitHub.