PyTorch FSDP와 ZeRO 기본기

마지막 수정:

pytorchdistributedfsdpzerosharding

DDP는 단순하지만 모든 rank가 model parameter, gradient, optimizer state를 들고 있다.

rank 0: params + grads + optimizer states
rank 1: params + grads + optimizer states
rank 2: params + grads + optimizer states
rank 3: params + grads + optimizer states

모델이 커지면 이 중복이 memory bottleneck이 된다.

FSDP와 ZeRO의 핵심은 이 중복을 줄이는 것이다.

replicate everything
  -> shard optimizer states
  -> shard gradients
  -> shard parameters

PyTorch FSDP는 module parameter를 shard하고, 계산에 필요할 때 잠깐 all-gather해서 사용한 뒤 다시 shard 상태로 돌린다.

before layer: all-gather params
compute layer
after layer: free full params / keep shards

이 카드의 목표는 FSDP 설정을 외우는 것이 아니라, DDP에서 무엇이 중복이었고 FSDP가 무엇을 shard하는지 구분하는 것이다.

확인

  • DDP에서 rank마다 중복 저장되는 세 가지 큰 항목은 무엇인가?
  • ZeRO-1, ZeRO-2, ZeRO-3는 각각 무엇을 shard하는가?
  • FSDP가 layer 계산 직전에 parameter를 all-gather하는 이유는 무엇인가?