ZeRO
마지막 수정:
ZeRO는 Zero Redundancy Optimizer의 줄임말이다. 이름 그대로 DP에서 중복 저장되는 학습 state를 줄이는 방법이다.
2P + 2P + kP 2P + 2P + kP / Nd 2P + (2P + kP) / Nd (2P + 2P + kP) / Nd vanilla DP에서는 모든 GPU가 같은 모델 복사본을 가지고 서로 다른 micro-batch를 처리한다.
GPU 0: full parameters + full gradients + full optimizer states
GPU 1: full parameters + full gradients + full optimizer states
GPU 2: full parameters + full gradients + full optimizer states
이 구조는 빠르게 병렬 처리할 수 있지만, 메모리 관점에서는 중복이 많다.
ZeRO는 이 중복을 DP rank들 사이에 나눠 저장한다.
ZeRO-1: optimizer state를 shard
ZeRO-2: optimizer state + gradient를 shard
ZeRO-3: optimizer state + gradient + parameter를 shard
무엇이 줄어드는가
mixed precision training에서 Adam을 쓴다고 하자. 대략 다음 세 덩어리가 중요하다.
parameters = model weights
gradients = backward가 만든 update 신호
optimizer states = Adam의 fp32 weights, momentum, variance
Adam의 optimizer state는 크다. 그래서 ZeRO-1은 먼저 optimizer state부터 나눈다. ZeRO-2는 gradient까지 나누고, ZeRO-3는 parameter까지 나눈다.
Activation은 왜 대상이 아닌가
ZeRO는 DP rank 사이에 중복된 model state를 줄인다. 그런데 activation은 DP rank마다 다르다.
GPU 0 activation = micro-batch A에서 생긴 activation
GPU 1 activation = micro-batch B에서 생긴 activation
즉 activation은 같은 것을 중복 저장한 게 아니라, 서로 다른 데이터에서 생긴 다른 값이다. 그래서 ZeRO가 직접 shard하는 대상이 아니다.
이 차이를 model state와 activation으로 나눠 보면 더 분명하다.
parameters:
모든 DP rank가 같은 모델을 복제하므로 중복
optimizer states:
같은 parameter에 대한 Adam state이므로 중복
gradients:
동기화된 update를 만들기 위한 model-sized state라 shard 가능
activations:
각 DP rank의 micro-batch에서 생긴 고유한 값이라 중복 아님
ZeRO-3가 parameter를 shard하더라도, 계산할 layer에서는 parameter를 all-gather해서 각 DP rank가 자기 micro-batch forward를 수행한다.
GPU 0: layer params all-gather -> micro-batch A forward -> activation A
GPU 1: layer params all-gather -> micro-batch B forward -> activation B
따라서 activation은 여전히 각 DP rank의 local micro-batch 기준으로 생긴다.
activation memory는 activation checkpointing, sequence/context parallelism, micro-batch 조정 같은 다른 방법으로 다룬다. 앞으로 볼 Tensor Parallelism은 layer 내부 계산 자체를 나누기 때문에 일부 activation을 hidden dimension 방향으로 shard할 수 있고, Sequence/Context Parallelism은 sequence dimension 방향의 activation memory를 더 직접적으로 다룬다.
핵심 trade-off
ZeRO stage가 올라갈수록 GPU 하나가 들고 있는 model-related memory는 줄어든다.
대신 필요한 순간에 shard를 맞추는 통신이 늘어난다.
memory saving up
communication complexity up
그래서 ZeRO를 볼 때는 항상 두 질문을 같이 봐야 한다.
무엇을 shard하는가?
그 shard를 계산에 쓰기 위해 어떤 collective가 필요한가?
확인
- vanilla DP에서 중복 저장되는 세 가지 큰 학습 state는 무엇인가?
- ZeRO-1, ZeRO-2, ZeRO-3는 각각 무엇까지 shard하는가?
- activation은 왜 ZeRO의 직접 shard 대상이 아닌가?