前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【Go】sync.WaitGroup 源码阅读

【Go】sync.WaitGroup 源码阅读

作者头像
JuneBao
发布2022-10-26 15:12:52
2310
发布2022-10-26 15:12:52
举报
文章被收录于专栏:JuneBaoJuneBao

没想到人家巧妙利用了数组连续内存和 int 精度丢失来存储和读取状态,我大意了,没有闪 ┗|`O′|┛ 嗷~~

WaitGroup

sync.WaitGroup 用于等待一组 goroutine 返回,如:

代码语言:javascript
复制
var wg = sync.WaitGroup{}

func do() {
    time.Sleep(time.Second)
    fmt.Println("done")
    wg.Done()
}

func main() {
    go do()
    go do()
    wg.Add(2)
    wg.Wait()
    fmt.Println("main done")
}

概览

如上面的例子, WaitGroup 只堆外暴露了三个方法:

代码语言:javascript
复制
// 等待的 goroutine 数加 delta
func (wg *WaitGroup) Add(delta int) 
// 等待的 goroutine 数减一
func (wg *WaitGroup) Done() 
// 阻塞,等待这一组 goroutine 全部退出
func (wg *WaitGroup) Wait()
代码语言:javascript
复制
type WaitGroup struct {
    noCopy noCopy
    state1 [3]uint32
}

WaitGroup 结构体中也只有两个字段:

  • noCopy: 用来保证不会被开发者错误拷贝
  • state1: 用来保存相关状态量

另外,他还提供了一个私有的方法用来获取状态和信号量

代码语言:javascript
复制
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

statep 就是状态量,注意这里通过 unsafe 将 3 位数组(共 96 位)强转成了 uint64 这会导致部分数据丢失,具体来说,在64位的机器上会丢失最低 32 位,也即 state1[2] 在 32 位机器上会丢失最高 32 位,也即 state1[0], 这也是 64 位和 32 位机器上数组三位元素表示意义不同的原因。

强转之后,以 64 位机器为例,数组第二位会作为 statep 的高 32 位,第一位会作为 statep 的低 32 位,也就是说,此时 statep 的结构如下:

代码语言:javascript
复制
+----------------------+-----------------------+
|                      |                       |
|      Counter         |       Waiter          |
|                      |                       |
+----------------------+-----------------------+

Add

代码语言:javascript
复制
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

Done 其实就是对 Add 的一个封装。

代码语言:javascript
复制
func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    // 把 delta 加到 count 中
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    // 获取 count
    v := int32(state >> 32)
    // 丢失高 32 位的 Counter, 得到 Waiter
    w := uint32(state)

    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    
    // Waiter 不等于 0 说明现在还有 goroutine 没有 done, 这时是不允许 Add 的
    // 也即在 Wait 的过程中不允许通过 Add 添加 
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 正常修改 Counter 后返回
    if v > 0 || w == 0 {
        return
    }
    
    // 到这说明 Counter == 0 并且 delta 不是一个正数(执行 Done,并且是最后一次 Done)
    
    // 状态改变,说明有人在 Wait 过程中 Add 了
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 状态置 0
    *statep = 0
    // 唤醒 Wait 中的 goroutine
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
    }
}

总结一下,首先 Done 只是对 Add 的简单封装,在 Add 时,通过巧妙利用精度丢失和位移运算分别计算出 add 后的 Counter 和 Waiter, 前者表示已经 add 了多少 Goroutine, 后者表示还有多少个 goroutine 需要 Wait, 这里需要注意,在 Wait 的过程中是不允许 Add 新 goroutine 的;在执行 Done 时,只是简单的将 Counter 减 1,直到 Counter == 1 时,也即最后一个 goroutine 已经执行完毕时,Done 会通知 Wait 停止阻塞,并将标志清空。

Wait

代码语言:javascript
复制
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        // Counter == 0, 没有 Add, 直接返回
        if v == 0 {
            return
        }
        // 每一次 CAS 让 Waiter 加一,并进入阻塞,等待最后一个 Done 的 goroutine 将其唤醒
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            runtime_Semacquire(semap)
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
        // 如果 CAS 比较没通过,说明在此过程中有 goroutine Done 了,需要重新去获取最新的状态
    }
}

总结

WaitGroup 用于阻塞某个 Goroutine 以等待一组 goroutine 返回,在实现上,它采用一个长度为 3 的 32 位无符号整型数组保存 Waiter, Counter, 和信号量,每次 Add 时,会将 Counder 加上 delta,而当执行 Done 或 delta 为负数时,如果 Done 的是最后一个 Goroutine, Add 会去唤醒 Wait

执行 Wait 只是将 Waiter 加一并阻塞等待 Add 的唤醒,所以其实 Waiter 的值只会是 0 或 1.

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-11-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • WaitGroup
    • 概览
      • Add
        • Wait
          • 总结
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档