JAX Distributed Training
이 경로는 JAX 단일 디바이스 학습 루프를 multi-device 관점으로 확장합니다.
PyTorch DDP는 wrapper가 model replica와 gradient all-reduce를 담당합니다. JAX에서는 먼저 array placement와 function transformation을 명시적으로 이해해야 합니다.
Mesh:
logical device axes
NamedSharding:
array placement rule
PartitionSpec:
which tensor dimension maps to which mesh axis
목표는 “JAX에서 DDP를 어떻게 쓰나”가 아니라, JAX가 분산 학습을 array program transformation으로 보는 방식을 이해하는 것입니다.
이 path는 두 단계로 나눕니다.
replicated data parallel
-> pmap mental model
-> mesh / sharding vocabulary
-> pjit / PartitionSpec
-> FSDP-style parameter sharding
-> multi-host scaling map
목표는 JAX/TPU 계열 학습 코드에서 자주 보이는 sharding annotation을 읽을 수 있게 되는 것입니다.
- JAX Sharding Mesh Basics — JAX distributed training의 출발점인 Mesh, NamedSharding, PartitionSpec을 PyTorch DDP wrapper 관점과 비교한다.
- JAX pmap Replicated Data Parallel — pmap으로 replicated data parallel training의 기본 구조를 이해하고 PyTorch DDP와 비교한다.
- JAX Data Parallel Training Map — JAX에서 data parallel training step을 params replication, batch sharding, gradient aggregation 관점으로 설계한다.
- JAX pjit와 PartitionSpec — pjit, Mesh, PartitionSpec으로 array sharding을 명시하는 JAX distributed programming model을 익힌다.
- JAX FSDP-style Parameter Sharding — PyTorch FSDP/ZeRO와 같은 memory 절감 아이디어를 JAX parameter sharding 관점으로 번역한다.
- JAX Multi-Host Scaling Map — 단일 머신 JAX sharding에서 multi-host 학습으로 넘어갈 때 필요한 개념을 정리한다.
- JAX vs PyTorch Comparison Report — 같은 tiny Transformer를 PyTorch와 JAX로 구현한 뒤 모델 상태, 학습 루프, profiling, distributed 관점의 차이를 정리한다.