DDP vs FSDP Lab
마지막 수정:
DDP와 FSDP는 둘 다 data parallel 축 위에서 작동하지만, 무엇을 복제하고 무엇을 shard하는지가 다르다.
DDP:
each rank has full params + full grads + full optimizer state
backward triggers gradient all-reduce
FSDP:
ranks keep parameter shards
layer compute triggers parameter all-gather
backward/update triggers reduce-scatter style communication
실습은 같은 tiny Transformer를 두 wrapper로만 바꿔 실행한다.
cd labs/pytorch-transformer
torchrun --nproc_per_node=2 distributed_train.py --steps 20 --device cuda
torchrun --nproc_per_node=2 fsdp_train.py --steps 20 --device cuda
비교할 값은 세 가지다.
tokens/sec
peak memory per GPU
profiler에서 보이는 communication op
작은 모델에서는 FSDP가 더 느릴 수 있다. 이것은 실패가 아니다. FSDP는 parameter memory를 줄이기 위해 더 많은 all-gather/reduce-scatter를 넣는다. 모델이 작거나 interconnect가 느리면 communication overhead가 memory 이득보다 커질 수 있다.
이 실습의 목표는 “FSDP가 항상 빠르다”를 확인하는 것이 아니라, DDP와 FSDP가 서로 다른 병목을 만든다는 것을 보는 것이다.
확인
- DDP에서 가장 먼저 profiler로 찾아야 할 collective는 무엇인가?
- FSDP에서 DDP보다 추가로 생기는 communication은 무엇인가?
- 작은 모델에서 FSDP가 DDP보다 느릴 수 있는 이유는 무엇인가?