Golang WaitGroup 底层原理及源码详解

阅读:1775

发布时间:2023年4月25日 10:03

golang
# 0 知识背景 在进入正文前,先对 `WaitGroup` 及其相关背景知识做个简单的介绍,这里主要是 `WaitGroup` 的基本使用,以及系统信号量的基础知识。对这些比较熟悉的小伙伴可以直接跳过这一节。 ## 0.1 WaitGroup `WaitGroup` 是 Golang 中最常见的并发控制技术之一,它的作用我们可以简单类比为其他语言中多线程并发控制中的 `join()`,实例代码如下: ```go package main import ( "fmt" "sync" "time" ) func main() { fmt.Println("Main starts...") var wg sync.WaitGroup // 2 指的是下面有两个协程需要等待 wg.Add(2) go waitFunc(&wg, 3) go waitFunc(&wg, 1) // 阻塞等待 wg.Wait() fmt.Println("Main ends...") } func waitFunc(wg *sync.WaitGroup, num int) { // 函数结束时告知 WaitGroup 自己已经结束 defer wg.Done() time.Sleep(time.Duration(num) * time.Second) fmt.Printf("Hello World from %v\n", num) } // 结果输出: Main starts... Hello World from 1 Hello World from 3 Main ends... ``` 如果这里没有 `WaitGroup`,主协程(main 函数)会直接跑到最后的 `Main ends...`,而没有中间两个 goroutine 的输出,加了 `WaitGroup` 后,main 就会在 `wg.Wait()` 处阻塞等待两个协程都结束后才继续执行。 上面我们看到的 `WaitGroup` 的三个方法:`Wait()`、`Add(int)` 和 `Done()` 也是 `WaitGroup` 对象仅有的三个方法。 ## 0.2 信号量(Semaphore) 信号量(Semaphore)是一种用于实现多进程或多线程之间同步和互斥的机制,也是 `WaitGroup` 中所采用的技术。并且 `WaitGroup` 自身的同步原理,也与信号量很相似。 由于翻译问题,不熟悉的小伙伴经常将信号量(Semaphore)和信号(Signal)搞混,这俩实际上是两个完全不同的东西。Semaphore 在英文中的本意是**旗语**,也就是航海领域的那个旗语,利用手旗或旗帜传递信号的沟通方式。在计算机领域,Semaphore,即信号量,在广义上也可以理解为一种进程、线程间的通信方式,但它的主要作用,正如前面所说,是用于实现进程、线程间的同步和互斥。 信号量本质上可以简单理解为一个整型数,主要包含两种操作:P(Proberen,测试)操作和 V(Verhogen,增加)操作。其中,P 操作会尝试获取一个信号量,如果信号量的值大于 0,则将信号量的值减 1 并继续执行;否则,当前进程或线程就会被阻塞,直到有其他进程或线程释放这个信号量为止。V 操作则是释放一个信号量,将信号量的值加 1。 可以把信号量看作是一种类似锁的东西,P 操作相当于获取锁,而 V 操作相当于释放锁。由于信号量是一种操作系统级别的机制,通常由内核提供支持,因此我们不用担心上述对信号量的操作本身会产生竞态条件,相信内核能搞定这种东西。 本文的重点不是信号量,因此不会过多展开关于信号量的技术细节,有兴趣的小伙伴可以查阅相关资料。 最后提一嘴技术之外的东西,Proberen 和 Verhogen 这俩单词眼生吧?因为它们是荷兰语,不是英语。为啥是荷兰语嘞?因为发明信号量的人,是上古计算机大神,来自荷兰的计算机先驱 Edsger W. Dijkstra 先生。嗯,对,就是那个 Dijkstra。 # 1 WaitGroup 底层原理 **声明:本文所用源码均基于 Go 1.20.3 版本**,不同版本 Go 的 `WaitGroup` 源码可能略有不同,但设计思想基本是一致的。 `WaitGroup` 相关源码非常短,加上注释和空行也只有 120 多行,它们全都在 `src/sync/waitgroup.go` 中。 ## 1.1 定义 先来看 `WaitGroup` 的定义,这里我把源文件中的注释都简单翻译了一下: ```go // WaitGroup 等待一组 Goroutine 完成。 // 主 Goroutine 调用 Add 方法设置要等待的 Goroutine 数量, // 然后每个 Goroutine 运行并在完成后调用 Done 方法。 // 同时,可以使用 Wait 方法阻塞,直到所有 Goroutine 完成。 // // WaitGroup 在第一次使用后不能被复制。 // // 根据 Go 内存模型的术语,Done 调用“同步于”任何它解除阻塞的 Wait 调用的返回。 type WaitGroup struct { noCopy noCopy state atomic.Uint64 // 高 32 位是计数器, 低 32 位是等待者数量(后文解释)。 sema uint32 } ``` `WaitGroup` 类型是一个结构体,它有三个私有成员,我们一个一个来看。 ### 1.1.1 noCopy 首先是 `noCopy`,这个东西是为了告诉编译器,`WaitGroup` 结构体对象不可复制,即 `wg2 := wg` 是非法的。之所以禁止复制,是为了防止可能发生的死锁。但实际上如果我们对 `WaitGroup` 对象进行复制后,至少在 1.20 版本下,Go 的编译器只是发出警告,没有阻止编译过程,我们依然可以编译成功。警告的内容如下: ``` assignment copies lock value to wg2: sync.WaitGroup contains sync.noCopy ``` 为什么编译器没有编译失败,我猜应该是 Go 官方想尽量减少编译器对程序的干预,而更多地交给程序员自己去处理(此时 Rust 发出了一阵笑声)。总之,我们在使用 `WaitGroup` 的过程中,不要去复制它就对了,不然非常容易产生死锁(其实结构体注释上也说了,WaitGroup 在第一次使用后不能被复制)。譬如我将文章开头代码中的 main 函数稍微改了改: ```go func main() { fmt.Println("Main starts...") var wg sync.WaitGroup // 2 指的是下面有两个协程需要等待 wg.Add(1) wg2 := wg wg2.Add(1) go waitFunc(&wg, 3) go waitFunc(&wg2, 1) // 阻塞等待 wg.Wait() wg2.Wait() fmt.Println("Main ends...") } // 输出结果 Main starts... Hello World from 1 Hello World from 3 fatal error: all goroutines are asleep - deadlock! goroutine 1 [semacquire]: sync.runtime_Semacquire(0xc000042060?) C:/Program Files/Go/src/runtime/sema.go:62 +0x27 sync.(*WaitGroup).Wait(0xe76b28?) C:/Program Files/Go/src/sync/waitgroup.go:116 +0x4b main.main() D:/Codes/Golang/waitgroup/main.go:23 +0x139 exit status 2 ``` 为什么会这样?因为 wg 已经 `Add(1)` 了,这时我们复制了 wg 给 wg2,并且是个浅拷贝,意味着 wg2 内实际上已经是 `Add(1)` 后的状态了(state 成员保存的状态,即它的值),此时我们再执行 `wg2.Add(1)`,其实相当于执行了两次 `wg2.Add(1)`。而后面 `waitFunc()` 中对 wg2 只进行了一次 `Done()` 释放操作,main 函数在 `wg2.Wait()` 时就陷入了无限等待,即 `all goroutines are asleep`。等看了后面 `Add()` 和 `Done()` 的原理后,再回头来看这段死锁的代码,会更加清晰。 那么这段代码能既复制,又不死锁吗?当然可以,只需要把 `wg2 := wg` 提到 `wg.Add(1)` 前面即可。 ### 1.1.2 state atomic.Uint64 `state` 是 `WaitGroup` 的核心,它是一个无符号的 64 位整型,并且用的是 `atomic` 包中的 `Uint64`,所以 `state` 本身是线程安全的。至于 `atomic.Uint64` 为什么能保证线程安全,因为它使用了 `CompareAndSwap(CAS)` 操作,而这个操作依赖于 CPU 提供的原子性指令,是 CPU 级的原子操作。 `state` 的高 32 位是计数器(counter),低 32 位是等待者数量(waiters)。其中计数器其实就是 `Add(int)` 数量的总和,譬如 `Add(1)` 后再 `Add(2)`,那么这个计数器就是 1 + 2 = 3;而等待数量就是现在有多少 goroutine 在执行 `Wait()` 等待 `WaitGroup` 被释放。 ### 1.1.3 sema uint32 这玩意儿就是信号量,它的用法我们到后文结合代码再讲。 ## 1.2 Add(delta int) 首先是 `Add(delta int)` 方法。`WaitGroup` 所有三个方法都没有返回值,并且只有 `Add` 拥有参数,整个设计可谓简洁到了极点。 `Add` 方法的第一句代码是: ```go if race.Enabled { if delta < 0 { // Synchronize decrements with Wait. race.ReleaseMerge(unsafe.Pointer(wg)) } race.Disable() defer race.Enable() } ``` `race.Enabled` 是判断当前程序是否开启了竞态条件检查,这个检查是在编译时需要我们手动指定的:`go build -race main.go`,默认情况下并不开启,即 `race.Enabled` 在默认情况下就是 `false`。这段代码里如果程序开启了竞态条件检查,会将其关闭,最后再重新打开。其他有关 `race` 的细节本文不再讨论,这对我们理解 `WaitGroup` 也没有太大影响,将其考虑进去反而会增加我们理解 `WaitGroup` 核心机制的复杂度,因此后续代码中也会忽略所有与 `race` 相关的部分。 `Add` 方法整理后的代码如下: ```go // Add 方法将 delta 值加上计数器,delta 可以为负数。如果计数器变为 0, // 则所有在 Wait 上阻塞的 Goroutine 都会被释放。 // 如果计数器变为负数,则 Add 方法会 panic。 // // 注意:当计数器为 0 时调用 delta 值为正数的 Add 方法必须在 Wait 方法之前执行。 // 而 delta 值为负数或者 delta 值为正数但计数器大于 0 时,则可以在任何时间点执行。 // 通常情况下,这意味着应该在创建 Goroutine 或其他等待事件的语句之前执行 Add 方法。 // 如果一个 WaitGroup 用于等待多组独立的事件, // 那么必须在所有先前的 Wait 调用返回之后再进行新的 Add 调用。 // 详见 WaitGroup 示例代码。 func (wg *WaitGroup) Add(delta int) { // 将 int32 的 delta 变成 unint64 后左移 32 位再与 state 累加。 // 相当于将 delta 与 state 的高 32 位累加。 state := wg.state.Add(uint64(delta) << 32) // 高 32 位,就是 counter,计数器 v := int32(state >> 32) // 低 32 位,就是 waiters,等待者数量 w := uint32(state) // 计数器为负数时直接 panic if v < 0 { panic("sync: negative WaitGroup counter") } // 当 Wait 和 Add 并发执行时,会有概率触发下面的 panic if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } // 如果计数器大于 0,或者没有任何等待者,即没有任何 goroutine 在 Wait(),那么就直接返回 if v > 0 || w == 0 { return } // 当 waiters > 0 时,这个 Goroutine 将计数器设置为 0。 // 现在不可能有对状态的并发修改: // - Add 方法不能与 Wait 方法同时执行, // - Wait 不会在看到计数器为 0 时增加等待者。 // 仍然需要进行简单的健全性检查来检测 WaitGroup 的误用情况。 if wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } // 重置 state 为 0 wg.state.Store(0) // 唤醒所有等待者 for ; w != 0; w-- { // 使用信号量控制唤醒等待者 runtime_Semrelease(&wg.sema, false, 0) } } ``` 这里我将原代码中的注释翻译成了中文,并且自己在每句代码前也都加了注释。 一开始,方法将参数 `delta` 变成 uint64 后左移 32 位,和 `state` 相加。因为 `state` 的高 32 位是这个 `WaitGroup` 的计数器,所以这里其实就是把计数器进行了累加操作: ```go state := wg.state.Add(uint64(delta) << 32) ``` 接着,程序会分别取出已经累加后的计数器 `v`,和当前的等待者数量 `w`: ```go v := int32(state >> 32) w := uint32(state) ``` 然后是几个判断: ```go // 计数器为负数时直接 panic if v < 0 { panic("sync: negative WaitGroup counter") } // 当 Wait 和 Add 并发执行时,会有概率触发下面的 panic if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } // 如果计数器大于 0,或者没有任何等待者, // 即没有任何 goroutine 在 Wait(),那么就直接返回 if v > 0 || w == 0 { return } ``` 注释已经比较清晰了,这里主要展开解释一下第二个 `if`:`if w != 0 && delta > 0 && v == int32(delta)`。 1. `w != 0` 意味着当前有 goroutine 在 `Wait()`; 2. `delta > 0` 意味着 `Add()` 传入的是正整数,也就是正常调用; 3. `v == int32(delta)` 意味着累加后的计数器等于传入的 `delta`,这里最容易想到的符合这个等式的场景是:**原计数器等于 0 时**,也就是 wg 第一次使用,或前面的 `Wait()` 已经全部结束时。 上述三个条件看上去有些冲突:`w != 0` 表示存在 `Wait()`,而 `v == int32(delta)` 按照分析应该不存在 `Wait()`。再往下分析,其实应该是 `v` 在获取的时候不存在 `Wait()`,而 `w` 在获取的时候存在 `Wait()`。会有这种可能吗?会!就是并发的时候:当前 goroutine 获取了 `v`,然后另一个 goroutine 立刻进行了 `Wait()`,接着本 goroutine 又获取了 `w`,过程如下: ![过程时序](http://image.dubingxuan.com/waitgroup/%E6%97%B6%E5%BA%8F.png) 我们可以用下面这段代码来复现这个 `panic`: ```go func main() { var wg sync.WaitGroup // 并发问题不易复现,所以循环多次 for i := 0; i < 100000; i++ { go addDoneFunc(&wg) go waitFunc(&wg) } wg.Wait() } func addDoneFunc(wg *sync.WaitGroup) { wg.Add(1) wg.Done() } func waitFunc(wg *sync.WaitGroup) { wg.Wait() } // 输出结果 panic: sync: WaitGroup misuse: Add called concurrently with Wait goroutine 71350 [running]: sync.(*WaitGroup).Add(0x0?, 0xbf8aa5?) C:/Program Files/Go/src/sync/waitgroup.go:65 +0xce main.addDoneFunc(0xc1cf66?, 0x0?) D:/Codes/Golang/waitgroup/main.go:19 +0x1e created by main.main D:/Codes/Golang/waitgroup/main.go:11 +0x8f exit status 2 ``` 这段代码可能要多运行几次才会看到上述效果,因为这种并发操作在整个 `WaitGroup` 的生命周期中会造成好几种 `panic`,包括 `Wait()` 方法中的。 因此,我们在使用 `WaitGroup` 的时候应当注意一点:**不要在被调用的 goroutine 内部使用 `Add`,而应当在外面使用**,也就是: ```go // 正确 wg.Add(1) go func(wg *sync.WaitGroup) { defer wg.Done() }(&wg) wg.Wait() // 错误 go func(wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() }(&wg) wg.Wait() ``` 从而避免并发导致的异常。 上面三个 `if` 都结束后,会再次对 `state` 的一致性进行判断,防止并发异常: ```go if wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } ``` 这里 `state.Load()` 包括后面会出现的 `Store()` 都是 `atomic.Uint64` 的原子操作。 根据前面代码的逻辑,当程序运行到这里时,计数器一定为 0,而等待者则可能 >= 0,于是代码会执行一次 `wg.state.Store(0)` 将 `state` 设为 0,接着执行通知等待者结束等待的操作: ```go wg.state.Store(0) for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) } ``` 好了,这里又是让人迷惑的地方,我第一次看到这段代码时产生了下面几个疑问: 1. 为什么 `Add` 方法会有计数器为 0 的分支逻辑?计数器不是累加的吗? 2. 为什么要在 `Add` 中通知等待者结束,不应该是 `Done` 方法吗? 3. 那个 `runtime_Semrelease(&wg.sema, false, 0)` 为什么需要循环 `w` 次? 一个一个来看。 - **为什么 `Add` 方法会有计数器为 0 的分支逻辑?** 首先,按照前面代码的逻辑,只有计数器 `v` 为 0 的时候,代码才会走到最后两句,而之所以为 0,是因为 `Add(delta int)` 的参数 `delta` 是一个 `int`,也就是说,**`delta` 可以为负数**!那什么时候会传入负数进来呢?`Done` 的时候。我们去看 `Done()` 的代码,会发现它非常简单: ```go // Done 给 WaitGroup 的计数器减 1。 func (wg *WaitGroup) Done() { wg.Add(-1) } ``` 所以,`Done` 操作或是我们手动给 `Add` 传入负数时,就会进入到 `Add` 最后几行逻辑,而 `Done` 本身也意味着当前 goroutine 的 `WaitGroup` 结束,需要同步给外部的 `Wait` 让它不再阻塞。 - **为什么要在 `Add` 中通知等待者结束,不应该是 `Done` 方法吗?** 嗯,这个问题其实在上一个问题已经一起解决了,因为 `Done()` 实际上调用了 `Add(-1)`。 - **那个 `runtime_Semrelease(&wg.sema, false, 0)` 为什么需要循环 `w` 次?** 这个函数按照字面意思,就是释放信号量。源码在 `src/sync/runtime.go` 中,函数声明如下: ```go // Semrelease 函数用于原子地增加 *s 的值, // 并在有等待 Semacquire 函数被阻塞的协程时通知它们继续执行。 // 它旨在作为同步库使用的简单唤醒基元,不应直接使用。 // 如果 handoff 参数为 true,则将 count 直接传递给第一个等待者。 // skipframes 参数表示在跟踪时要忽略的帧数,从 runtime_Semrelease 的调用者开始计数。 func runtime_Semrelease(s *uint32, handoff bool, skipframes int) ``` 第一个参数就是信号量的值本身,释放时会 +1。 第二个参数 `handoff` 在我查阅了资料后,根据我的理解,应该是:当 `handoff` 为 `false` 时,仅正常唤醒其他等待的协程,但是不会立即调度被唤醒的协程;而当 `handoff` 为 `true` 时,会立刻调度被唤醒的协程。 第三个参数 `skipframes`,看上去应当也和调度有关,但具体含义我不太确定,这里就不猜了(水平有限,见谅哈)。 按照信号量本身的机制,这里释放时会 +1,同理还存在一个信号量获取函数 `runtime_Semacquire(s *uint32)` 会在信号量 > 0 时将信号量 -1,否则等待,它会在 `Wait()` 中被调用。这也是 `runtime_Semrelease` 需要循环 `w` 次的原因:因为那 `w` 个 `Wait()` 中会调用 `runtime_Semacquire` 并不断将信号量 -1,也就是减了 `w` 次,所以两个地方需要对冲一下嘛。 信号量和 `WaitGroup` 的机制很像,但计数器又是反的,所以这里再多嘴补充几句: 信号量获取时(`runtime_Semacquire`),其实就是在阻塞等待,P(Proberen,测试)操作,如果此时信号量 > 0,则获取成功,并将信号量 -1,否则继续等待; 信号量释放时(`runtime_Semrelease`),会把信号量 +1,也就是 V(Verhogen,增加)操作。 ## 1.2 Done() `Done()` 方法我们在上面已经看到过了: ```go // Done 给 WaitGroup 的计数器减 1。 func (wg *WaitGroup) Done() { wg.Add(-1) } ``` ## 1.3 Wait() 同样的,这里我会把与 `race` 相关的代码都删掉: ```go // Wait 会阻塞,直到计数器为 0。 func (wg *WaitGroup) Wait() { for { state := wg.state.Load() v := int32(state >> 32) // 计数器 w := uint32(state) // 等待者数量 if v == 0 { // 计数器为 0,直接返回。 return } // 增加等待者数量 if wg.state.CompareAndSwap(state, state+1) { // 获取信号量 runtime_Semacquire(&wg.sema) // 这里依然是为了防止并发问题 if wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return } } } ``` 比 `Add` 简单多了,而且有了前面 `Add` 的长篇大论为基础,`Wait` 的代码看上去一目了然。 当计数器为 0,即没有任何 goroutine 调用 `Add` 时,直接调用 `Wait`,没有任何意义,因此直接返回,也不操作信号量。 最后 `Wait` 也有一个防止并发问题的判断,而这个 panic 同样可以用前面 `Add` 中的那段并发问题代码复现,大家可以试试。 `Wait` 中唯一不同的是,它用了一个无限循环 `for{}`,为什么?这是因为,`wg.state.CompareAndSwap(state, state+1)` 这个原子操作因为并发等原因有可能失败,此时就需要重新获取 `state`,把整个过程再走一遍。而一旦操作成功,`Wait` 会在 `runtime_Semacquire(&wg.sema)` 处阻塞,直到 `Done` 操作将计数器减为 0,`Add` 中释放了信号量。 # 2 结语 至此,`WaitGroup` 的源码已全部解析完毕。作为 Golang 中最重要的并发组件之一,`WaitGroup` 的源码居然只有这么寥寥百行代码,倒是给我们理解它的原理降低了不少难度。 开文之前我也没想到会写这么多东西,能看到这里的小伙伴们,感谢你们的耐心。 本人水平有限,若文中有什么纰漏或错误,还请大家不吝指出,再次感谢!