前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Go中sync.WaitGroup处理协程同步

Go中sync.WaitGroup处理协程同步

原创
作者头像
:Darwin
发布2023-08-11 19:37:05
3020
发布2023-08-11 19:37:05
举报
文章被收录于专栏:WorkLogsWorkLogs

简介

一个 sync.WaitGroup 对象可以等待一组协程结束。它很好地解决了 goroutine 同步的问题。

通常用于以下几种场景:

  • 需要等待 goroutine 多路任务完成
  • 主 goroutine 需要等待子 goroutine
  • 顺序任务需要等待前置任务

使用方法

  • main协程通过调用 wg.Add(delta int) 设置 worker 协程的个数,然后创建 worker 协程
  • worker协程执行结束以后,都要调用 wg.Done()
  • main协程调用 wg.Wait(),直到所有 worker 协程全部执行结束后返回

使用示例

  • WaitGroup内部使用一个计数器count
  • Add方法会增加计数器的值
  • Done方法会减少计数器的值
  • Wait方法会阻塞,直到计数器的值变为0
代码语言:go
复制
// 初始化 WaitGroup
var wg sync.WaitGroup

// 告诉 WaitGroup 有 2 个 goroutine 需要等待
wg.Add(2)

// 启动第一个 goroutine
go func() {
    defer wg.Done()

    // do something
    time.Sleep(time.Second * 1)
    fmt.Println("goroutine 1 done")
}()

// 启动第二个 goroutine
go func() {  
    defer wg.Done()

    // do something
    time.Sleep(time.Second * 2)
    fmt.Println("goroutine 2 done")
}()

// 等待所有注册过的 goroutine 都执行完
wg.Wait()

// 主 goroutine 等待 wg.Wait() 完成
// go next

实现原理

  • 通过原子操作统一记录计数和等待变量。
  • 在计数操作与等待操作之间加入同步机制。
  • 使用信号量机制通知等待线程。
  • 通过可见性和竞争检测保证正确性。

具体一点:

  1. 使用一个64位的原子操作变量state来存储计数和等待线程数。高32位作为计数,低32位作为等待线程数。
  2. Add方法通过原子操作将计数调整,加入必要的同步操作保证顺序。
  3. Wait方法通过循环检测计数值,如果不为0则加1等待变量,否则返回。加等待变量表示有新的等待线程。
  4. 多次Add调用可能导致计数临界下降为0时有等待线程,这时需要额外同步检查避免错误。
  5. 32位系统需要检查变量对齐情况,可能需要交换变量存储位置保证原子方式有效。
  6. 内部使用runtime提供的信号量调用runtime_Semacquire/runtime_Semrelease来实现等待通知功能。
  7. 使用内存锁race.Enable完成可见性保证和竞争检测。

sync.WaitGroup 源码

代码语言:go
复制
package sync

import (
	"internal/race"
	"sync/atomic"
	"unsafe"
)

// WaitGroup等待一组协程完成。
// 主协程调用Add来设置
// 等待的协程。然后是每个协程
// 运行并在完成时调用Done。同时,
// Wait可以用来阻塞,直到所有的协程都完成。

// WaitGroup首次使用后不能复制。
type WaitGroup struct {
	noCopy noCopy

	// 64位值:高32位为计数器,低32位为等待计数。
    // 64位原子操作需要64位对齐,但是32位编译器只保证64位字段是32位对齐的。
	// 出于这个原因,在32位体系结构上,我们需要检查state()中state1是否对齐,并在需要时动态地“交换”字段顺序。
    state1 uint64
	state2 uint32
}

// State返回指向存储在wg.state*中的State和sema字段的指针。
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// State1是64位对齐的:不做任何事情。
		return &wg.state1, &wg.state2
	} else {
		// State1是32位对齐,但不是64位对齐:这意味着(&state1)+4是64位对齐的。
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}

// Add向WaitGroup计数器添加增量,增量可能为负。
// 如果计数器变为零,则释放被Wait阻塞的所有协程。
// 如果计数器为负,则添加panics。

// 请注意,当计数器为零时,具有正增量的调用必须在Wait之前发生。
// 具有负增量的调用,或者在计数器大于零时开始的具有正增量的调用,可能在任何时候发生。
// 通常,这意味着对Add的调用应该在语句创建要等待的程序或其他事件之前执行。
// 如果重用WaitGroup来等待几个独立的事件集,则必须在所有先前的wait调用返回之后发生新的Add调用。
// 参见WaitGroup示例。
func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // 提前触发nil延迟
		if delta < 0 {
			// 与Wait同步减量。
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32)
	w := uint32(state)
	if race.Enabled && delta > 0 && v == int32(delta) {
		// 第一个增量必须与Wait同步。
        // 需要将其建模为读取,因为可能有多个并发的wg。计数器从0转换。
		race.Read(unsafe.Pointer(semap))
	}
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// 当 waiters > 0时,这个协程将counter设置为0。
    // 状态不可能同时发生突变
    // -添加不能与等待同时发生,
    // - Wait如果看到counter == 0,则不会增加waiters。
    // 仍然要做一个便宜的完整性检查来检测WaitGroup的误用。
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 将waiters计数重置为0。 
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

// Done将WaitGroup counter减1。
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

// 等待阻塞直到WaitGroup counter为0。
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // 提前触发nil延迟
		race.Disable()
	}
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			if race.Enabled && w == 0 {
				// Wait必须与第一个Add同步。
                // 需要将其建模为写操作与Add中的读操作竞争。
                // 因此,只能给第一个 waiter 写入,
                // 否则并发等待将相互竞争。
				race.Write(unsafe.Pointer(semap))
			}
			runtime_Semacquire(semap)
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

internal/race

主要用于静态编译时的并发数据竞争检测,可以更便捷地检查并发程序是否安全。

race.Enabled表示是否开启竞争检测功能。

race.Enable()开启竞争检测。

race.Disable()关闭竞争检测。

race.Acquire()模拟对共享资源获取锁。

race.ReleaseMerge()模拟对共享资源解锁并合并锁定计数。

race.Write()模拟对共享资源的写操作。

race.Read()模拟对共享资源的读操作。

信号量 semaphore

在系统中,会给每一个进程一个信号量,代表每个进程目前的状态。未得到控制权的进程,会在特定的地方被迫停下来,等待可以继续进行的信号到来。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 简介
  • 使用方法
  • 使用示例
  • 实现原理
    • sync.WaitGroup 源码
    • internal/race 包
    • 信号量 semaphore
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档