ZeRO-1

마지막 수정:

trainingdistributedzeroreduce-scatterall-gather

ZeRO-1은 optimizer state만 DP rank들 사이에 shard한다.

Forward Backward ReduceScatter Grad AllGather Params
compute
F0F1F2 B2B1B0 F0
network
RS2RS1RS0 AllGather updated params
1. local full gradients

Backward는 각 GPU에서 full gradient를 만든다.

2. reduce-scatter

각 GPU는 자기가 update할 shard의 averaged gradient만 받는다.

3. local shard update

자기 optimizer state shard로 parameter shard만 업데이트한다.

4. all-gather

다음 forward를 위해 updated full parameters를 복원한다.

ZeRO-1은 gradient 통신 결과를 shard로 남기고, optimizer update 뒤에는 parameter shard를 다시 all-gather한다.

parameter와 backward 중 생기는 local gradient는 여전히 GPU마다 full copy로 존재한다.

GPU 0: full parameters + local full gradients + optimizer state shard 0
GPU 1: full parameters + local full gradients + optimizer state shard 1
GPU 2: full parameters + local full gradients + optimizer state shard 2

핵심은 optimizer update를 모든 GPU가 똑같이 반복하지 않는다는 점이다.

GPU 0: parameter shard 0만 update
GPU 1: parameter shard 1만 update
GPU 2: parameter shard 2만 update

왜 reduce-scatter가 필요한가

각 GPU는 자기 micro-batch로 backward를 했기 때문에 local full gradient를 가진다.

GPU 0 local grad: [g0_A, g0_B, g0_C]
GPU 1 local grad: [g1_A, g1_B, g1_C]
GPU 2 local grad: [g2_A, g2_B, g2_C]

DP에서는 서로 다른 micro-batch에서 나온 gradient를 평균내야 한다. 하지만 ZeRO-1에서는 각 GPU가 자기 optimizer state shard에 해당하는 gradient shard만 있으면 된다.

GPU 0 needs reduced grad A
GPU 1 needs reduced grad B
GPU 2 needs reduced grad C

그래서 full all-reduce 대신 reduce-scatter를 쓴다.

reduce:  같은 shard 위치의 gradient를 합친다
scatter: 합쳐진 결과를 shard 담당 GPU에 남긴다

개념적 결과는 다음과 같다.

GPU 0: g0_A + g1_A + g2_A
GPU 1: g0_B + g1_B + g2_B
GPU 2: g0_C + g1_C + g2_C

중요한 점은 한 GPU에 다 모았다가 다시 뿌린다는 뜻이 아니라는 것이다. collective 구현은 여러 GPU가 동시에 통신하면서 각 shard의 reduced result가 해당 GPU에 남도록 만든다.

왜 all-gather가 필요한가

reduce-scatter 후 각 GPU는 자기 optimizer state shard와 gradient shard로 자기 parameter shard만 업데이트한다.

GPU 0: updated parameter shard A
GPU 1: updated parameter shard B
GPU 2: updated parameter shard C

그런데 ZeRO-1의 다음 forward는 full parameters를 필요로 한다. parameter는 shard된 상태로 계산하는 것이 아니라, DP처럼 각 GPU에 전체 parameter가 있어야 한다.

그래서 optimizer step 뒤에 updated parameter shard를 all-gather한다.

updated parameter shards
-> all-gather
-> every GPU gets full updated parameters

한 step 요약

1. full parameters로 forward
2. backward로 local full gradients 생성
3. gradient를 reduce-scatter해서 각 GPU가 reduced grad shard만 받음
4. 각 GPU가 자기 optimizer state shard로 parameter shard만 update
5. updated parameter shards를 all-gather해서 full parameters 복원

확인

  • ZeRO-1에서 shard되는 것은 무엇인가?
  • ZeRO-1에서 full all-reduce 대신 reduce-scatter를 쓰는 이유는 무엇인가?
  • optimizer step 뒤에 all-gather가 필요한 이유는 무엇인가?