ZeRO-1
마지막 수정:
ZeRO-1은 optimizer state만 DP rank들 사이에 shard한다.
Backward는 각 GPU에서 full gradient를 만든다.
각 GPU는 자기가 update할 shard의 averaged gradient만 받는다.
자기 optimizer state shard로 parameter shard만 업데이트한다.
다음 forward를 위해 updated full parameters를 복원한다.
parameter와 backward 중 생기는 local gradient는 여전히 GPU마다 full copy로 존재한다.
GPU 0: full parameters + local full gradients + optimizer state shard 0
GPU 1: full parameters + local full gradients + optimizer state shard 1
GPU 2: full parameters + local full gradients + optimizer state shard 2
핵심은 optimizer update를 모든 GPU가 똑같이 반복하지 않는다는 점이다.
GPU 0: parameter shard 0만 update
GPU 1: parameter shard 1만 update
GPU 2: parameter shard 2만 update
왜 reduce-scatter가 필요한가
각 GPU는 자기 micro-batch로 backward를 했기 때문에 local full gradient를 가진다.
GPU 0 local grad: [g0_A, g0_B, g0_C]
GPU 1 local grad: [g1_A, g1_B, g1_C]
GPU 2 local grad: [g2_A, g2_B, g2_C]
DP에서는 서로 다른 micro-batch에서 나온 gradient를 평균내야 한다. 하지만 ZeRO-1에서는 각 GPU가 자기 optimizer state shard에 해당하는 gradient shard만 있으면 된다.
GPU 0 needs reduced grad A
GPU 1 needs reduced grad B
GPU 2 needs reduced grad C
그래서 full all-reduce 대신 reduce-scatter를 쓴다.
reduce: 같은 shard 위치의 gradient를 합친다
scatter: 합쳐진 결과를 shard 담당 GPU에 남긴다
개념적 결과는 다음과 같다.
GPU 0: g0_A + g1_A + g2_A
GPU 1: g0_B + g1_B + g2_B
GPU 2: g0_C + g1_C + g2_C
중요한 점은 한 GPU에 다 모았다가 다시 뿌린다는 뜻이 아니라는 것이다. collective 구현은 여러 GPU가 동시에 통신하면서 각 shard의 reduced result가 해당 GPU에 남도록 만든다.
왜 all-gather가 필요한가
reduce-scatter 후 각 GPU는 자기 optimizer state shard와 gradient shard로 자기 parameter shard만 업데이트한다.
GPU 0: updated parameter shard A
GPU 1: updated parameter shard B
GPU 2: updated parameter shard C
그런데 ZeRO-1의 다음 forward는 full parameters를 필요로 한다. parameter는 shard된 상태로 계산하는 것이 아니라, DP처럼 각 GPU에 전체 parameter가 있어야 한다.
그래서 optimizer step 뒤에 updated parameter shard를 all-gather한다.
updated parameter shards
-> all-gather
-> every GPU gets full updated parameters
한 step 요약
1. full parameters로 forward
2. backward로 local full gradients 생성
3. gradient를 reduce-scatter해서 각 GPU가 reduced grad shard만 받음
4. 각 GPU가 자기 optimizer state shard로 parameter shard만 update
5. updated parameter shards를 all-gather해서 full parameters 복원
확인
- ZeRO-1에서 shard되는 것은 무엇인가?
- ZeRO-1에서 full all-reduce 대신 reduce-scatter를 쓰는 이유는 무엇인가?
- optimizer step 뒤에 all-gather가 필요한 이유는 무엇인가?