DDP vs FSDP Lab

마지막 수정:

pytorchdistributedddpfsdpprofiling

DDP와 FSDP는 둘 다 data parallel 축 위에서 작동하지만, 무엇을 복제하고 무엇을 shard하는지가 다르다.

DDP:
  each rank has full params + full grads + full optimizer state
  backward triggers gradient all-reduce

FSDP:
  ranks keep parameter shards
  layer compute triggers parameter all-gather
  backward/update triggers reduce-scatter style communication

실습은 같은 tiny Transformer를 두 wrapper로만 바꿔 실행한다.

cd labs/pytorch-transformer
torchrun --nproc_per_node=2 distributed_train.py --steps 20 --device cuda
torchrun --nproc_per_node=2 fsdp_train.py --steps 20 --device cuda

비교할 값은 세 가지다.

tokens/sec
peak memory per GPU
profiler에서 보이는 communication op

작은 모델에서는 FSDP가 더 느릴 수 있다. 이것은 실패가 아니다. FSDP는 parameter memory를 줄이기 위해 더 많은 all-gather/reduce-scatter를 넣는다. 모델이 작거나 interconnect가 느리면 communication overhead가 memory 이득보다 커질 수 있다.

이 실습의 목표는 “FSDP가 항상 빠르다”를 확인하는 것이 아니라, DDP와 FSDP가 서로 다른 병목을 만든다는 것을 보는 것이다.

확인

  • DDP에서 가장 먼저 profiler로 찾아야 할 collective는 무엇인가?
  • FSDP에서 DDP보다 추가로 생기는 communication은 무엇인가?
  • 작은 모델에서 FSDP가 DDP보다 느릴 수 있는 이유는 무엇인가?