JAX Distributed Training

이 경로는 JAX 단일 디바이스 학습 루프를 multi-device 관점으로 확장합니다.

PyTorch DDP는 wrapper가 model replica와 gradient all-reduce를 담당합니다. JAX에서는 먼저 array placement와 function transformation을 명시적으로 이해해야 합니다.

Mesh:
logical device axes

NamedSharding:
array placement rule

PartitionSpec:
which tensor dimension maps to which mesh axis

목표는 “JAX에서 DDP를 어떻게 쓰나”가 아니라, JAX가 분산 학습을 array program transformation으로 보는 방식을 이해하는 것입니다.

이 path는 두 단계로 나눕니다.

replicated data parallel
  -> pmap mental model
  -> mesh / sharding vocabulary
  -> pjit / PartitionSpec
  -> FSDP-style parameter sharding
  -> multi-host scaling map

목표는 JAX/TPU 계열 학습 코드에서 자주 보이는 sharding annotation을 읽을 수 있게 되는 것입니다.

  1. JAX Sharding Mesh Basics — JAX distributed training의 출발점인 Mesh, NamedSharding, PartitionSpec을 PyTorch DDP wrapper 관점과 비교한다.
  2. JAX pmap Replicated Data Parallel — pmap으로 replicated data parallel training의 기본 구조를 이해하고 PyTorch DDP와 비교한다.
  3. JAX Data Parallel Training Map — JAX에서 data parallel training step을 params replication, batch sharding, gradient aggregation 관점으로 설계한다.
  4. JAX pjit와 PartitionSpec — pjit, Mesh, PartitionSpec으로 array sharding을 명시하는 JAX distributed programming model을 익힌다.
  5. JAX FSDP-style Parameter Sharding — PyTorch FSDP/ZeRO와 같은 memory 절감 아이디어를 JAX parameter sharding 관점으로 번역한다.
  6. JAX Multi-Host Scaling Map — 단일 머신 JAX sharding에서 multi-host 학습으로 넘어갈 때 필요한 개념을 정리한다.
  7. JAX vs PyTorch Comparison Report — 같은 tiny Transformer를 PyTorch와 JAX로 구현한 뒤 모델 상태, 학습 루프, profiling, distributed 관점의 차이를 정리한다.