PyTorch FSDP와 ZeRO 기본기
마지막 수정:
DDP는 단순하지만 모든 rank가 model parameter, gradient, optimizer state를 들고 있다.
rank 0: params + grads + optimizer states
rank 1: params + grads + optimizer states
rank 2: params + grads + optimizer states
rank 3: params + grads + optimizer states
모델이 커지면 이 중복이 memory bottleneck이 된다.
FSDP와 ZeRO의 핵심은 이 중복을 줄이는 것이다.
replicate everything
-> shard optimizer states
-> shard gradients
-> shard parameters
PyTorch FSDP는 module parameter를 shard하고, 계산에 필요할 때 잠깐 all-gather해서 사용한 뒤 다시 shard 상태로 돌린다.
before layer: all-gather params
compute layer
after layer: free full params / keep shards
이 카드의 목표는 FSDP 설정을 외우는 것이 아니라, DDP에서 무엇이 중복이었고 FSDP가 무엇을 shard하는지 구분하는 것이다.
확인
- DDP에서 rank마다 중복 저장되는 세 가지 큰 항목은 무엇인가?
- ZeRO-1, ZeRO-2, ZeRO-3는 각각 무엇을 shard하는가?
- FSDP가 layer 계산 직전에 parameter를 all-gather하는 이유는 무엇인가?