ZeRO-3

마지막 수정:

trainingdistributedzeroparametersall-gatherreduce-scatter

ZeRO-3는 optimizer state, gradient에 더해 parameter까지 shard한다.

AllGather Params Parameter Free Forward Backward ReduceScatter Grads
params
AG0 free 0 AG1 free 1 AG2 free 2 AG1 free 1 AG0
compute
F0F1F2 B2B1B0
grads
RS2 RS1 RS0
forward all-gather

각 layer를 계산하기 직전에 해당 layer parameter를 임시로 모은다.

F2 keep

마지막 forward layer parameter는 바로 B2에서 다시 쓰므로 free하지 않는다.

backward all-gather

B1, B0는 이미 full params를 버렸으므로 다시 all-gather한다.

after B0

마지막 free는 보통 생략되어 보이지만 full params를 계속 들고 있다는 뜻은 아니다.

ZeRO-3의 all-gather 순서가 0 -> 1 -> 2 -> 1 -> 0이 되는 이유는 forward는 앞에서 뒤로, backward는 뒤에서 앞으로 진행되기 때문이다.

ZeRO-1/2에서는 forward와 backward 동안 모든 GPU가 full parameters를 가지고 있었다.

ZeRO-1/2:
parameters full
gradients shard 또는 full
optimizer states shard

ZeRO-3에서는 parameter도 DP rank에 나눠 저장한다.

GPU 0: parameter shard A
GPU 1: parameter shard B
GPU 2: parameter shard C

그래서 계산할 layer가 오면 해당 layer의 parameter를 잠깐 all-gather한다.

all-gather layer params
compute layer
free full layer params

ZeRO-1/2의 all-gather와 다르다

ZeRO-1/2의 all-gather는 optimizer update 이후 full parameters를 다시 복원해 계속 들고 있기 위한 all-gather였다.

ZeRO-1/2:
updated parameter shards
-> all-gather
-> every GPU keeps full updated parameters

ZeRO-3의 all-gather는 계산 직전에 필요한 layer parameter만 잠깐 모으는 것이다.

ZeRO-3:
parameter shards stay sharded
-> gather layer 0 only when layer 0 computes
-> free layer 0 full params

즉 이름은 둘 다 all-gather지만, 의미가 다르다.

ZeRO-1/2 all-gather = 다음 forward를 위해 full params를 복원
ZeRO-3 all-gather   = 지금 계산할 layer params를 임시 복원

왜 순서가 0 -> 1 -> 2 -> 1 -> 0인가

forward는 앞에서 뒤로 간다.

forward: layer 0 -> layer 1 -> layer 2

따라서 forward parameter all-gather도:

AG0 -> F0
AG1 -> F1
AG2 -> F2

backward는 뒤에서 앞으로 간다.

backward: layer 2 -> layer 1 -> layer 0

그런데 layer 2 parameter는 F2에서 이미 all-gather했고 바로 B2에서 다시 필요하다. 그래서 F2 직후에 free하지 않고 B2에서 재사용한다.

결과적으로 전체 all-gather 순서는:

forward side:  0 -> 1 -> 2
backward side:      1 -> 0
total:         0 -> 1 -> 2 -> 1 -> 0

backward에서 왜 parameter all-gather가 필요한가

Forward Y = XW
Backward dW = X^T dY dX = dY W^T

Parameter gradient

dW는 activation X와 upstream gradient dY로 계산된다.

needs X needs dY

Input gradient

dX는 이전 layer로 넘길 gradient이고, 여기에 parameter W가 직접 필요하다.

needs dY needs W
GPU 0 W shard 0
GPU 1 W shard 1
GPU 2 W shard 2
-> all-gather W ->
temporary full W compute dX, then free
ZeRO-3 backward all-gather의 직접 이유는 activation이 없어서가 아니라, dX = dY W^T 계산에 full layer parameter가 필요하기 때문이다.

linear layer를 보자.

Y = XW

backward에서는 두 가지를 계산한다.

dW = X^T dY
dX = dY W^T

dW는 parameter gradient다. 이것은 activation X와 upstream gradient dY로 계산된다.

하지만 이전 layer로 넘길 dX를 계산하려면 parameter W가 필요하다.

dX = dY W^T

ZeRO-3에서는 forward 후 full W를 버리고 shard만 유지한다. 그래서 backward에서 해당 layer의 dX를 계산하려면 다시 parameter를 all-gather해야 한다.

이것은 activation recomputation 때문이 아니다. activation checkpointing을 쓰면 별도의 recomputation이 추가될 수 있지만, ZeRO-3 backward all-gather의 기본 이유는 parameter가 shard되어 있고 full parameter를 계속 들고 있지 않기 때문이다.

reduce-scatter는 그대로 필요하다

ZeRO-3에서도 gradient는 shard로 남긴다.

backward layer 2 -> reduce-scatter grad 2
backward layer 1 -> reduce-scatter grad 1
backward layer 0 -> reduce-scatter grad 0

각 GPU는 자기 gradient shard와 optimizer state shard로 자기 parameter shard만 업데이트한다.

GPU 0: grad shard A + optim shard A -> param shard A update
GPU 1: grad shard B + optim shard B -> param shard B update
GPU 2: grad shard C + optim shard C -> param shard C update

하지만 update 후 full parameters를 계속 all-gather해서 들고 있지는 않는다. 다음 forward가 시작되면 layer별로 다시 on-demand all-gather한다.

한 step 요약

1. layer 0 params all-gather -> F0 -> free
2. layer 1 params all-gather -> F1 -> free
3. layer 2 params all-gather -> F2 -> keep for B2
4. B2 -> free params 2 -> reduce-scatter grad 2
5. layer 1 params all-gather -> B1 -> free -> reduce-scatter grad 1
6. layer 0 params all-gather -> B0 -> reduce-scatter grad 0
7. optimizer update keeps parameters sharded

확인

  • ZeRO-3에서 ZeRO-2보다 추가로 shard하는 것은 무엇인가?
  • ZeRO-3의 all-gather는 ZeRO-1/2의 all-gather와 어떤 점에서 다른가?
  • backward에서 dX를 계산하려면 왜 parameter W가 필요한가?