ZeRO

마지막 수정:

trainingdistributedzerosharding

ZeROZero Redundancy Optimizer의 줄임말이다. 이름 그대로 DP에서 중복 저장되는 학습 state를 줄이는 방법이다.

Parameters Gradients Optimizer States Activations: not ZeRO-sharded
Vanilla DP 2P + 2P + kP
GPU 0
full
full
full
GPU 1
full
full
full
GPU 2
full
full
full
ZeRO-1 2P + 2P + kP / Nd
GPU 0
full
full
0
GPU 1
full
full
1
GPU 2
full
full
2
ZeRO-2 2P + (2P + kP) / Nd
GPU 0
full
0
0
GPU 1
full
1
1
GPU 2
full
2
2
ZeRO-3 (2P + 2P + kP) / Nd
GPU 0
0
0
0
GPU 1
1
1
1
GPU 2
2
2
2
P는 parameter 메모리 단위, kP는 optimizer state 메모리, Nd는 DP rank 수다. 아래 stage로 갈수록 더 많은 model state가 DP rank에 나뉜다. Activation은 각 DP rank의 서로 다른 micro-batch에서 생기므로 이 표의 shard 대상이 아니다.

vanilla DP에서는 모든 GPU가 같은 모델 복사본을 가지고 서로 다른 micro-batch를 처리한다.

GPU 0: full parameters + full gradients + full optimizer states
GPU 1: full parameters + full gradients + full optimizer states
GPU 2: full parameters + full gradients + full optimizer states

이 구조는 빠르게 병렬 처리할 수 있지만, 메모리 관점에서는 중복이 많다.

ZeRO는 이 중복을 DP rank들 사이에 나눠 저장한다.

ZeRO-1: optimizer state를 shard
ZeRO-2: optimizer state + gradient를 shard
ZeRO-3: optimizer state + gradient + parameter를 shard

무엇이 줄어드는가

mixed precision training에서 Adam을 쓴다고 하자. 대략 다음 세 덩어리가 중요하다.

parameters       = model weights
gradients        = backward가 만든 update 신호
optimizer states = Adam의 fp32 weights, momentum, variance

Adam의 optimizer state는 크다. 그래서 ZeRO-1은 먼저 optimizer state부터 나눈다. ZeRO-2는 gradient까지 나누고, ZeRO-3는 parameter까지 나눈다.

Activation은 왜 대상이 아닌가

ZeRO는 DP rank 사이에 중복된 model state를 줄인다. 그런데 activation은 DP rank마다 다르다.

GPU 0 activation = micro-batch A에서 생긴 activation
GPU 1 activation = micro-batch B에서 생긴 activation

즉 activation은 같은 것을 중복 저장한 게 아니라, 서로 다른 데이터에서 생긴 다른 값이다. 그래서 ZeRO가 직접 shard하는 대상이 아니다.

이 차이를 model state와 activation으로 나눠 보면 더 분명하다.

parameters:
모든 DP rank가 같은 모델을 복제하므로 중복

optimizer states:
같은 parameter에 대한 Adam state이므로 중복

gradients:
동기화된 update를 만들기 위한 model-sized state라 shard 가능

activations:
각 DP rank의 micro-batch에서 생긴 고유한 값이라 중복 아님

ZeRO-3가 parameter를 shard하더라도, 계산할 layer에서는 parameter를 all-gather해서 각 DP rank가 자기 micro-batch forward를 수행한다.

GPU 0: layer params all-gather -> micro-batch A forward -> activation A
GPU 1: layer params all-gather -> micro-batch B forward -> activation B

따라서 activation은 여전히 각 DP rank의 local micro-batch 기준으로 생긴다.

activation memory는 activation checkpointing, sequence/context parallelism, micro-batch 조정 같은 다른 방법으로 다룬다. 앞으로 볼 Tensor Parallelism은 layer 내부 계산 자체를 나누기 때문에 일부 activation을 hidden dimension 방향으로 shard할 수 있고, Sequence/Context Parallelism은 sequence dimension 방향의 activation memory를 더 직접적으로 다룬다.

핵심 trade-off

ZeRO stage가 올라갈수록 GPU 하나가 들고 있는 model-related memory는 줄어든다.

대신 필요한 순간에 shard를 맞추는 통신이 늘어난다.

memory saving up
communication complexity up

그래서 ZeRO를 볼 때는 항상 두 질문을 같이 봐야 한다.

무엇을 shard하는가?
그 shard를 계산에 쓰기 위해 어떤 collective가 필요한가?

확인

  • vanilla DP에서 중복 저장되는 세 가지 큰 학습 state는 무엇인가?
  • ZeRO-1, ZeRO-2, ZeRO-3는 각각 무엇까지 shard하는가?
  • activation은 왜 ZeRO의 직접 shard 대상이 아닌가?