Earlier articles already covered how to use sync.WaitGroup, so this one skips the basic usage and looks at how it is implemented.
Let's start with a basic usage example:
import (
	"fmt"
	"sync"
)

func DoGroup() {
	wg := sync.WaitGroup{}
	for i := 1; i <= 2; i++ {
		// Call Add before starting the goroutine so that Done can never run first.
		wg.Add(1)
		go doPrintln(&wg, i, 0)
	}
	wg.Wait()
}

func doPrintln(wg *sync.WaitGroup, i, k int) {
	defer wg.Done()
	fmt.Println("i:", i, "k:", k)
}
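This example shows the main goroutine waiting for the child goroutines to finish doPrintln before it exits, which completes the whole flow.
Now look at the following code. What does it print?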
func DoGroupC() {
	wg := sync.WaitGroup{}
	for k := 1; k <= 2; k++ {
		// Add before starting the goroutine so that Done can never run first.
		wg.Add(1)
		go DoGroup2(&wg, k)
	}
	wg.Wait()
	fmt.Println("DoGroupC")
}

func DoGroup2(mWg *sync.WaitGroup, k int) {
	defer func() {
		fmt.Println("DoGroup2 k", k)
		mWg.Done()
	}()
	wg := sync.WaitGroup{}
	for i := 1; i <= 2; i++ {
		wg.Add(1)
		go doPrintln(&wg, i, k)
	}
	wg.Wait()
}

func doPrintln(wg *sync.WaitGroup, i, k int) {
	defer wg.Done()
	fmt.Println("i:", i, "k:", k)
}
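In the code above, the main goroutine waits for two child goroutines, and each child in turn waits for its own two grandchild goroutines. One possible output is shown below. The relative order of the i: 1 and i: 2 lines is not deterministic, but for a given k both i lines always appear before the matching DoGroup2 k line, and DoGroupC is always printed last.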
i: 2 k: 2
i: 2 k: 1
i: 1 k: 2
DoGroup2 k 2
i: 1 k: 1
DoGroup2 k 1
DoGroupC
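The ordering guarantee described above can be checked by recording events instead of printing them. The following is a minimal sketch, not part of the original example: runNested, record, and indexOf are hypothetical names introduced here, and the mutex-protected slice is an addition that exists only so the order can be inspected afterwards.

```go
package main

import (
	"fmt"
	"sync"
)

// runNested mirrors the nested DoGroupC/DoGroup2 structure, but appends
// each event to a slice (guarded by a mutex) instead of printing it.
func runNested() []string {
	var (
		mu     sync.Mutex
		events []string
	)
	record := func(s string) {
		mu.Lock()
		defer mu.Unlock()
		events = append(events, s)
	}

	var outer sync.WaitGroup
	for k := 1; k <= 2; k++ {
		outer.Add(1)
		go func(k int) {
			defer outer.Done()
			var inner sync.WaitGroup
			for i := 1; i <= 2; i++ {
				inner.Add(1)
				go func(i int) {
					defer inner.Done()
					record(fmt.Sprintf("i:%d k:%d", i, k))
				}(i)
			}
			inner.Wait() // both i events for this k are recorded by now
			record(fmt.Sprintf("DoGroup2 k %d", k))
		}(k)
	}
	outer.Wait() // both DoGroup2 events are recorded by now
	record("DoGroupC")
	return events
}

// indexOf returns the position of s in events, or -1 if it is absent.
func indexOf(events []string, s string) int {
	for idx, e := range events {
		if e == s {
			return idx
		}
	}
	return -1
}

func main() {
	for _, e := range runNested() {
		fmt.Println(e)
	}
}
```

Running it several times should show the i lines in varying order, while each DoGroup2 line stays after its own two i lines and DoGroupC stays last.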
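Having seen the behavior, let's look at the source code to see what mechanism makes it work. The listing below matches an older Go release in which WaitGroup keeps its state in a [3]uint32 array; newer releases lay the fields out differently.
About WaitGroup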
// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls Add to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls Done when finished. At the same time,
// Wait can be used to block until all goroutines have finished.
//
// A WaitGroup must not be copied after first use.
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}
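state1 is an array of three uint32 values that holds both the state and the semaphore (sema). The state packs two counters: counter, the number of goroutines that have not finished yet, and waiter, the number of goroutines blocked in Wait until the group finishes.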
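To make the packing concrete, here is a small sketch of how a counter and a waiter count can share one 64-bit word. The helpers pack, unpack, and addDelta are hypothetical names used only for illustration; they are not part of the sync package.

```go
package main

import "fmt"

// pack stores counter in the high 32 bits and waiters in the low 32 bits.
func pack(counter int32, waiters uint32) uint64 {
	return uint64(uint32(counter))<<32 | uint64(waiters)
}

// unpack splits a state word back into counter and waiters.
func unpack(state uint64) (counter int32, waiters uint32) {
	return int32(state >> 32), uint32(state)
}

// addDelta adds delta to the counter half only, using the same
// shift-then-add arithmetic as the Add method (without the atomics).
func addDelta(state uint64, delta int) uint64 {
	return state + uint64(delta)<<32
}

func main() {
	state := pack(3, 2)
	state = addDelta(state, -1)
	counter, waiters := unpack(state)
	fmt.Println(counter, waiters) // prints: 2 2
}
```

Because delta is shifted into the high half before the addition, a negative delta wraps around in unsigned arithmetic and lowers the counter without touching the waiter bits.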
WaitGroup provides three methods:
Add(delta int)
Done()
Wait()
First, let's look at the Add method.
func (wg *WaitGroup) Add(delta int) {
	// Get pointers to the state word and the semaphore.
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	// Shift delta left by 32 bits and add it to state, i.e. add it to counter.
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	// Extract counter (high 32 bits).
	v := int32(state >> 32)
	// Extract the waiter count (low 32 bits).
	w := uint32(state)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(semap))
	}
	if v < 0 {
		// counter went negative after the addition: panic.
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	*statep = 0
	for ; w != 0; w-- {
		// Release the semaphore once per iteration; each release wakes one waiter.
		runtime_Semrelease(semap, false, 0)
	}
}
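Add mainly adds delta to the counter stored in state. When the counter reaches 0 and there are waiters, it releases the semaphore once per waiter, which wakes every goroutine blocked in Wait.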
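The "wake every waiter" behavior can be observed from user code. The sketch below is an illustration under assumptions: wakeAll is a hypothetical helper, and the second WaitGroup (done) is there only to track the waiter goroutines themselves.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// wakeAll starts the given number of goroutines that call Wait on the
// same WaitGroup, then drops the counter to 0 and reports how many of
// them returned from Wait.
func wakeAll(waiters int) int32 {
	var wg sync.WaitGroup
	wg.Add(1) // one pending unit of work keeps the counter above 0

	var woken int32
	var done sync.WaitGroup // tracks the waiter goroutines, not the work
	for n := 0; n < waiters; n++ {
		done.Add(1)
		go func() {
			defer done.Done()
			wg.Wait()
			atomic.AddInt32(&woken, 1)
		}()
	}

	// The counter reaches 0 here. Goroutines already blocked in Wait are
	// released; any that call Wait later see counter == 0 and return at once.
	wg.Done()
	done.Wait()
	return atomic.LoadInt32(&woken)
}

func main() {
	fmt.Println(wakeAll(5)) // prints: 5
}
```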
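About the Done() method
// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}
Done simply decrements counter by 1 by calling Add(-1). When that call brings counter to 0, the Add logic runs on the finishing goroutine and wakes the waiters.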
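Since Done is just Add(-1), calling it more often than Add drives the counter below zero and hits the panic seen in Add. A minimal sketch that captures the panic message follows; doneWithoutAdd is a hypothetical helper name.

```go
package main

import (
	"fmt"
	"sync"
)

// doneWithoutAdd calls Done on a fresh WaitGroup and returns the
// recovered panic message, or "" if no panic occurred.
func doneWithoutAdd() (msg string) {
	defer func() {
		if r := recover(); r != nil {
			msg = fmt.Sprint(r)
		}
	}()
	var wg sync.WaitGroup
	wg.Done() // counter goes from 0 to -1
	return ""
}

func main() {
	fmt.Println(doneWithoutAdd()) // prints: sync: negative WaitGroup counter
}
```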
The Wait() method
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	for {
		// Load the current state.
		state := atomic.LoadUint64(statep)
		// Extract counter (high 32 bits).
		v := int32(state >> 32)
		// Extract the waiter count (low 32 bits).
		w := uint32(state)
		if v == 0 {
			// Counter is 0: all goroutines have finished, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		// Use CAS (compare-and-swap) to increment waiter; on failure, retry in
		// the next loop iteration. CAS keeps the count correct even when
		// several goroutines call Wait at the same time.
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(semap))
			}
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}
The Wait() method mainly increments waiter and then blocks on the semaphore.
In short, WaitGroup implements goroutine waiting mainly through a semaphore.
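To tie the pieces together, here is a toy model of the same idea. It is a sketch under stated assumptions, not the real implementation: miniWaitGroup, newMiniWaitGroup, and runWorkers are hypothetical names, an unbuffered channel stands in for the runtime semaphore, and the race-detector hooks and misuse checks are left out.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// miniWaitGroup is a simplified model of the counter/waiter scheme.
type miniWaitGroup struct {
	state uint64        // high 32 bits: counter, low 32 bits: waiters (first field, so 64-bit aligned)
	sema  chan struct{} // assumption: a channel plays the role of the semaphore
}

func newMiniWaitGroup() *miniWaitGroup {
	return &miniWaitGroup{sema: make(chan struct{})}
}

// Add adds delta to the counter and, when the counter drops to 0,
// wakes every goroutine that registered itself as a waiter.
func (wg *miniWaitGroup) Add(delta int) {
	state := atomic.AddUint64(&wg.state, uint64(delta)<<32)
	v := int32(state >> 32) // counter
	w := uint32(state)      // waiters
	if v < 0 {
		panic("miniWaitGroup: negative counter")
	}
	if v > 0 || w == 0 {
		return
	}
	atomic.StoreUint64(&wg.state, 0) // reset waiters before waking them
	for ; w != 0; w-- {
		wg.sema <- struct{}{} // one send wakes one waiter
	}
}

// Done decrements the counter by one.
func (wg *miniWaitGroup) Done() { wg.Add(-1) }

// Wait returns at once if the counter is 0; otherwise it registers as a
// waiter with a CAS loop and blocks until Add wakes it.
func (wg *miniWaitGroup) Wait() {
	for {
		state := atomic.LoadUint64(&wg.state)
		if int32(state>>32) == 0 {
			return
		}
		if atomic.CompareAndSwapUint64(&wg.state, state, state+1) {
			<-wg.sema
			return
		}
	}
}

// runWorkers starts n goroutines and returns how many had finished
// by the time Wait returned.
func runWorkers(n int) int32 {
	wg := newMiniWaitGroup()
	var finished int32
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			atomic.AddInt32(&finished, 1)
		}()
	}
	wg.Wait()
	return atomic.LoadInt32(&finished)
}

func main() {
	fmt.Println(runWorkers(4)) // prints: 4
}
```

The channel keeps the sketch short; the real type relies on runtime_Semacquire and runtime_Semrelease, which are not available to user code.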