ZeRO-2
마지막 수정:
ZeRO-2는 ZeRO-1에서 한 단계 더 나아가 gradient도 shard한다.
During backward bucket
Before optimizer step
ZeRO-1과 ZeRO-2의 큰 통신 흐름은 비슷하다.
backward 중: reduce-scatter gradients
optimizer update 후: all-gather updated parameters
차이는 gradient를 메모리에 어떻게 남기느냐이다.
ZeRO-1: optimizer state shard
ZeRO-2: optimizer state shard + gradient shard
ZeRO-1에서 남은 낭비
ZeRO-1에서는 optimizer state가 shard되어 있으므로, 각 GPU는 자기 parameter shard를 업데이트하는 gradient shard만 필요하다.
하지만 backward 중에는 각 GPU에서 local full gradient가 생긴다.
GPU 0 local grad: [g0_A, g0_B, g0_C]
GPU 1 local grad: [g1_A, g1_B, g1_C]
GPU 2 local grad: [g2_A, g2_B, g2_C]
ZeRO-1은 이 gradient를 reduce-scatter해서 각 GPU가 자기 shard의 reduced gradient를 얻는다.
ZeRO-2는 여기서 더 명확하게 말한다.
어차피 optimizer step에 필요한 건 gradient shard뿐이다.
그러면 full gradient를 오래 들고 있지 말고 shard만 남기자.
Gradient 생명주기
ZeRO-2에서는 backward 중 layer나 bucket의 gradient가 준비되면 바로 reduce-scatter한다.
backward layer 2
-> local grad 2 생성
-> reduce-scatter grad 2
-> 내 grad shard만 유지하고 나머지 local grad는 release
이 과정을 layer나 bucket마다 반복한다.
그래서 optimizer step 전 최종 상태는 다음과 같다.
GPU 0: reduced grad shard A + optimizer state shard A
GPU 1: reduced grad shard B + optimizer state shard B
GPU 2: reduced grad shard C + optimizer state shard C
각 GPU는 자기 shard만 업데이트한다.
GPU 0: parameter shard A update
GPU 1: parameter shard B update
GPU 2: parameter shard C update
왜 all-gather는 그대로 필요한가
ZeRO-2에서도 parameter는 forward/backward 동안 full copy로 필요하다.
optimizer step 직후에는 각 GPU가 자기 parameter shard만 업데이트한 상태다.
GPU 0: updated A
GPU 1: updated B
GPU 2: updated C
다음 forward를 하려면 모든 GPU가 full updated parameters를 가져야 한다.
all-gather updated parameter shards
-> every GPU gets [updated A, updated B, updated C]
그래서 ZeRO-2에서도 all-gather는 ZeRO-1과 동일하게 필요하다.
ZeRO-1과 ZeRO-2의 차이
ZeRO-1:
parameters full
gradients full 성격이 더 강함
optimizer states shard
ZeRO-2:
parameters full
gradients shard
optimizer states shard
통신 패턴은 크게 늘지 않지만, gradient memory가 줄어든다. 그래서 ZeRO-2는 ZeRO-1보다 메모리 효율이 좋고, ZeRO-3보다 통신 부담은 작다.
확인
- ZeRO-2에서 ZeRO-1보다 추가로 shard하는 것은 무엇인가?
- ZeRO-2에서도 reduce-scatter가 필요한 이유는 무엇인가?
- ZeRO-2에서도 optimizer update 뒤에 all-gather가 필요한 이유는 무엇인가?