waitgroup
waitgroup
WaitGroup
,由 Go 标准库提供,它的功能就是用于等待一组协程运行完毕。
package main
import (
"fmt"
"sync"
)
func main() {
var wg sync.WaitGroup
for i := range 10 {
wg.Add(1)
go func() {
defer wg.Done()
fmt.Println(i)
}()
}
wg.Wait()
}
这是一段非常简单的代码,它的功能就是开启 10 个协程打印 0-9,并等待它们运行完毕。它的用法不再赘述,接下来我们来了解下它的基本工作原理,一点也不复杂。
结构
它的类型定义位于sync/waitgroup.go
文件中
type WaitGroup struct {
noCopy noCopy
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}
字段释义如下:
state
,表示 WaitGroup 的状态,高 32 位用于统计被等待协程的数量,低 32 位用于统计等待 wg 完成的协程数量。sema
,信号量,在sync
标准库里它几乎无处不在。
它的核心就在于Add()
和Wait()
这两个方法,基本工作原理就是信号量,Wait()
方法尝试获取信号量,Add()
方法释放信号量,来实现 M 个协程等待一组 N 个协程运行完毕。
Add
Add 方法就是增加需要等待协程的数量。
func (wg *WaitGroup) Add(delta int) {
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32)
w := uint32(state)
if v < 0 {
panic("sync: negative WaitGroup counter")
}
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
if v > 0 || w == 0 {
return
}
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
流程如下:
它首先会对
wg.state
进行移位操作,分别获取高 32 位和低 32 位,对应变量v
和w
state := wg.state.Add(uint64(delta) << 32) v := int32(state >> 32) w := uint32(state)
然后开始判断,
v
代表的是 wg 计数,w
代表的等待 wg 完成的协程数量如果
v
小于 0,直接panic
,负数没有任何意义if v < 0 { panic("sync: negative WaitGroup counter") }
w
不为 0,且delta
与v
相等,表示Wait()
方法与Add()
方法被并发地调用,这是错误的使用方式if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") }
如果
v
大于 0,或者w
等于 0,表示现在没有等待 wg 完成的协程,可以直接返回if v > 0 || w == 0 { return }
走到这一步说明
v
等于 0,且w
大于 0,即当前没有协程运行,但是有协程正在等待 wg 完成,所以就需要释放信号量,唤醒这些协程。if wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } wg.state.Store(0) for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) }
Done()
方法其实就是Add(-1)
,没有什么要讲的。
Wait
如果当前有其它协程需要等待运行完成,Wait
方法的调用会使当前协程陷入阻塞。
func (wg *WaitGroup) Wait() {
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
return
}
// Increment waiters count.
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
它的流程就是一个 for 循环
读取高 32 位和低 32 位,得到需要被等待协程的数量,和等待协程的数量,如果没有协程需要等待,就直接返回
state := wg.state.Load() v := int32(state >> 32) w := uint32(state) if v == 0 { return }
否则就通过 CAS 操作将等待协程数量加一,然后尝试获取信号量,进入阻塞等待队列
// Increment waiters count. if wg.state.CompareAndSwap(state, state+1) { runtime_Semacquire(&wg.sema) ... }
当等待协程被唤醒后(因为所有被等待的协程都运行完毕了,释放了信号量),检查
state
,如果不为 0,表示在Wait()
和Add()
又被并发的使用了if wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return
如果 CAS 没有更新成功,则继续循环
小结
最后要提醒下,在使用WaitGroup
时,Add
和Wait
不要并发的调用。