ZeRO-3
마지막 수정:
ZeRO-3는 optimizer state, gradient에 더해 parameter까지 shard한다.
각 layer를 계산하기 직전에 해당 layer parameter를 임시로 모은다.
마지막 forward layer parameter는 바로 B2에서 다시 쓰므로 free하지 않는다.
B1, B0는 이미 full params를 버렸으므로 다시 all-gather한다.
마지막 free는 보통 생략되어 보이지만 full params를 계속 들고 있다는 뜻은 아니다.
ZeRO-1/2에서는 forward와 backward 동안 모든 GPU가 full parameters를 가지고 있었다.
ZeRO-1/2:
parameters full
gradients shard 또는 full
optimizer states shard
ZeRO-3에서는 parameter도 DP rank에 나눠 저장한다.
GPU 0: parameter shard A
GPU 1: parameter shard B
GPU 2: parameter shard C
그래서 계산할 layer가 오면 해당 layer의 parameter를 잠깐 all-gather한다.
all-gather layer params
compute layer
free full layer params
ZeRO-1/2의 all-gather와 다르다
ZeRO-1/2의 all-gather는 optimizer update 이후 full parameters를 다시 복원해 계속 들고 있기 위한 all-gather였다.
ZeRO-1/2:
updated parameter shards
-> all-gather
-> every GPU keeps full updated parameters
ZeRO-3의 all-gather는 계산 직전에 필요한 layer parameter만 잠깐 모으는 것이다.
ZeRO-3:
parameter shards stay sharded
-> gather layer 0 only when layer 0 computes
-> free layer 0 full params
즉 이름은 둘 다 all-gather지만, 의미가 다르다.
ZeRO-1/2 all-gather = 다음 forward를 위해 full params를 복원
ZeRO-3 all-gather = 지금 계산할 layer params를 임시 복원
왜 순서가 0 -> 1 -> 2 -> 1 -> 0인가
forward는 앞에서 뒤로 간다.
forward: layer 0 -> layer 1 -> layer 2
따라서 forward parameter all-gather도:
AG0 -> F0
AG1 -> F1
AG2 -> F2
backward는 뒤에서 앞으로 간다.
backward: layer 2 -> layer 1 -> layer 0
그런데 layer 2 parameter는 F2에서 이미 all-gather했고 바로 B2에서 다시 필요하다. 그래서 F2 직후에 free하지 않고 B2에서 재사용한다.
결과적으로 전체 all-gather 순서는:
forward side: 0 -> 1 -> 2
backward side: 1 -> 0
total: 0 -> 1 -> 2 -> 1 -> 0
backward에서 왜 parameter all-gather가 필요한가
Y = XW dW = X^T dY dX = dY W^T Parameter gradient
dW는 activation X와 upstream gradient dY로 계산된다.
Input gradient
dX는 이전 layer로 넘길 gradient이고, 여기에 parameter W가 직접 필요하다.
dX = dY W^T 계산에 full layer parameter가 필요하기 때문이다.linear layer를 보자.
Y = XW
backward에서는 두 가지를 계산한다.
dW = X^T dY
dX = dY W^T
dW는 parameter gradient다. 이것은 activation X와 upstream gradient dY로 계산된다.
하지만 이전 layer로 넘길 dX를 계산하려면 parameter W가 필요하다.
dX = dY W^T
ZeRO-3에서는 forward 후 full W를 버리고 shard만 유지한다. 그래서 backward에서 해당 layer의 dX를 계산하려면 다시 parameter를 all-gather해야 한다.
이것은 activation recomputation 때문이 아니다. activation checkpointing을 쓰면 별도의 recomputation이 추가될 수 있지만, ZeRO-3 backward all-gather의 기본 이유는 parameter가 shard되어 있고 full parameter를 계속 들고 있지 않기 때문이다.
reduce-scatter는 그대로 필요하다
ZeRO-3에서도 gradient는 shard로 남긴다.
backward layer 2 -> reduce-scatter grad 2
backward layer 1 -> reduce-scatter grad 1
backward layer 0 -> reduce-scatter grad 0
각 GPU는 자기 gradient shard와 optimizer state shard로 자기 parameter shard만 업데이트한다.
GPU 0: grad shard A + optim shard A -> param shard A update
GPU 1: grad shard B + optim shard B -> param shard B update
GPU 2: grad shard C + optim shard C -> param shard C update
하지만 update 후 full parameters를 계속 all-gather해서 들고 있지는 않는다. 다음 forward가 시작되면 layer별로 다시 on-demand all-gather한다.
한 step 요약
1. layer 0 params all-gather -> F0 -> free
2. layer 1 params all-gather -> F1 -> free
3. layer 2 params all-gather -> F2 -> keep for B2
4. B2 -> free params 2 -> reduce-scatter grad 2
5. layer 1 params all-gather -> B1 -> free -> reduce-scatter grad 1
6. layer 0 params all-gather -> B0 -> reduce-scatter grad 0
7. optimizer update keeps parameters sharded
확인
- ZeRO-3에서 ZeRO-2보다 추가로 shard하는 것은 무엇인가?
- ZeRO-3의 all-gather는 ZeRO-1/2의 all-gather와 어떤 점에서 다른가?
- backward에서
dX를 계산하려면 왜 parameterW가 필요한가?