前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >sync.WaitGroup深入源码理解

sync.WaitGroup深入源码理解

作者头像
公众号-利志分享
发布2022-04-25 08:58:56
3790
发布2022-04-25 08:58:56
举报
文章被收录于专栏:利志分享

关于sync.WaitGroup的使用,之前的文章也有介绍,这个文章我就不那么简单的说这个sync.WaitGroup的使用,而是讲讲它的实现原理。

首先我们还是来看一下官方示例:

代码语言:javascript
复制
func DoGroup() {
  wg := sync.WaitGroup{}
  for i := 1; i <= 2; i++ {
    go doPrintln(&wg, i, 0)
    wg.Add(1)
  }
  wg.Wait()
}

func doPrintln(wg *sync.WaitGroup, i, k int) {
  defer func() {
    wg.Done()
  }()
  fmt.Println("i:", i, "k:", k)
}

这个示例讲述了主协程等待子协程执行完doPrintln函数之后,主线程才退出,完成整个执行流程。

下面我们看下下面的代码,看看执行结果是什么?

代码语言:javascript
复制
func DoGroupC() {
  wg := sync.WaitGroup{}
  for k := 1; k <= 2; k++ {
    go DoGroup2(&wg, k)
    wg.Add(1)
  }
  wg.Wait()
  fmt.Println("DoGroupC")
}

func DoGroup2(mWg *sync.WaitGroup, k int) {
  defer func() {
    fmt.Println("DoGroup2 k", k)
    mWg.Done()
  }()
  wg := sync.WaitGroup{}
  for i := 1; i <= 2; i++ {
    go doPrintln(&wg, i, k)
    wg.Add(1)
  }
  wg.Wait()
}

func doPrintln(wg *sync.WaitGroup, i, k int) {
  defer func() {
    wg.Done()
  }()
  fmt.Println("i:", i, "k:", k)
}

上面的代码是主协程等待两个子协程执行,然后两个子协程又等待子子协程的执行,执行结果如下,但是其实i==1和i==2是随机出现的,但是肯定都是在DoGroup2 k 1或者DoGroup2 k 2之前出现的。

代码语言:javascript
复制
i: 2 k: 2
i: 2 k: 1
i: 1 k: 2
DoGroup2 k 2
i: 1 k: 1
DoGroup2 k 1
DoGroupC

看了上面的实现,我们来看下源码的实现,是什么机制来实现这个功能的呢?下面我们看看源码

关于WaitGroup

代码语言:javascript
复制
// A WaitGroup waits for a collection of goroutines to finish.
// WaitGroup等待一组goroutine完成。
// The main goroutine calls Add to set the number of
// goroutine通过调用 Add(),增加等待的goroutines的数量。
// goroutines to wait for. Then each of the goroutines
// goroutines组中的goroutine等执行完成后,会调用Done()。
// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
// 可以用wait()等待所有goroutines完成。
// A WaitGroup must not be copied after first use.
// 不要被复制使用
type WaitGroup struct {
  noCopy noCopy

  // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
  //64 字节:高32位是counter, 低32位是等待者的数量。
  // 64-bit atomic operations require 64-bit alignment, but 32-bit
  //64位原子操作需要64位对齐,但是32位系统编译器无法保证它
  // compilers do not ensure it. So we allocate 12 bytes and then use
  // 编译不能保证它,所以我们使用12字节
  // the aligned 8 bytes in them as state, and the other 4 as storage
  // 使用8字节作为状态,其余四个字节用在放信号。
  // for the sema.
  state1 [3]uint32
}

state1是一个长度为3的数组,其中包含state和semap(信号量),state是两个计数器,一个是未执行结束的groutine计数器:counter,一个是等待group-group结束的groutine数量:waiter。

WaitGroup提供了三个方法:

Add(delta int)

Done()

Wait()

首先我们看一下Add方法

代码语言:javascript
复制
func (wg *WaitGroup) Add(delta int) {
  //获取state和sema地址指针
  statep, semap := wg.state()
  if race.Enabled {
    _ = *statep // trigger nil deref early
    if delta < 0 {
      // Synchronize decrements with Wait.
      race.ReleaseMerge(unsafe.Pointer(wg))
    }
    race.Disable()
    defer race.Enable()
  }
  //把delta左移32位,累计添加到state中,即累加到counter中
  state := atomic.AddUint64(statep, uint64(delta)<<32)
  //获取counter的值
  v := int32(state >> 32)
  //获取waiter值
  w := uint32(state)
  if race.Enabled && delta > 0 && v == int32(delta) {
    // The first increment must be synchronized with Wait.
    // Need to model this as a read, because there can be
    // several concurrent wg.counter transitions from 0.
    race.Read(unsafe.Pointer(semap))
  }
  if v < 0 {
    //经过累加后counter的值变成负数,则panic
    panic("sync: negative WaitGroup counter")
  }
  if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  }
  if v > 0 || w == 0 {
    return
  }
  // This goroutine has set counter to 0 when waiters > 0.
  // Now there can't be concurrent mutations of state:
  // - Adds must not happen concurrently with Wait,
  // - Wait does not increment waiters if it sees counter == 0.
  // Still do a cheap sanity check to detect WaitGroup misuse.
  if *statep != state {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  }
  // Reset waiters count to 0.
  *statep = 0
  for ; w != 0; w-- {
    //释放信号量,执行一次,释放一个,唤醒一个等待着
    runtime_Semrelease(semap, false, 0)
  }
}

Add主要是把delta添加到state的counter里面,然后通过counter等于0的时候,根据waiter数值释放等量信号量,把等量的grouptine全部唤醒。

关于Done()方法

代码语言:javascript
复制
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
  wg.Add(-1)
}

Done方法就是把counter减1,如果后面执行到Add就会等待一个完成的groutine把等待着唤醒。

Wait()方法

代码语言:javascript
复制
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
  statep, semap := wg.state()
  if race.Enabled {
    _ = *statep // trigger nil deref early
    race.Disable()
  }
  for {
    //获取state值
    state := atomic.LoadUint64(statep)
    //获取counter的值
    v := int32(state >> 32)
    //获取waiter的值
    w := uint32(state)
    if v == 0 {
      //如果counter的值为0,则说明所有的groutine都推出了,不需要在等待,直接返回
      // Counter is 0, no need to wait.
      if race.Enabled {
        race.Enable()
        race.Acquire(unsafe.Pointer(wg))
      }
      return
    }
    // Increment waiters count.
    //使用cas(比较交换算法)累加waiter,失败则下一次循环,通过case算法保证多个groutine同事执行wait也能正常累加
    if atomic.CompareAndSwapUint64(statep, state, state+1) {
      if race.Enabled && w == 0 {
        // Wait must be synchronized with the first Add.
        // Need to model this is as a write to race with the read in Add.
        // As a consequence, can do the write only for the first waiter,
        // otherwise concurrent Waits will race with each other.
        race.Write(unsafe.Pointer(semap))
      }
      runtime_Semacquire(semap)
      if *statep != 0 {
        panic("sync: WaitGroup is reused before previous Wait has returned")
      }
      if race.Enabled {
        race.Enable()
        race.Acquire(unsafe.Pointer(wg))
      }
      return
    }
  }
}

Wait()方法主要是累加waiter,阻塞等待信号量。

WaitGroup主要是通过信号量来实现groutine的等待

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-01-03,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 利志分享 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档