JAX Data Parallel Training Map

마지막 수정:

jaxdistributeddata-parallelismpjitpmap

JAX data parallelism도 본질은 PyTorch DDP와 같다.

JAX Sharding Vocabulary Describe how arrays live on devices before scaling the train step.
Mesh -> NamedSharding -> PartitionSpec
data:0device 0batch rows 0..n
data:1device 1batch rows n..2n
data:2device 2batch rows 2n..3n
data:3device 3batch rows 3n..4n
Array [batch, seq]
+
PartitionSpec P("data", None)
->
Meaning shard batch, replicate seq
JAX distributed work starts by making array placement explicit instead of hiding it inside a wrapper.
각 device가 batch shard를 처리한다
params는 replica로 유지한다
gradient를 data axis에 대해 평균낸다
같은 update를 모든 replica에 적용한다

차이는 표현 방식이다. PyTorch는 DDP wrapper가 이 과정을 담당한다. JAX에서는 pmap, pjit, shard_map 같은 transformation과 sharding rule로 표현한다.

학습 관점의 mental model은 다음이다.

global batch [B, T]
  -> shard batch axis across devices
  -> local loss/grad
  -> psum/pmean over data axis
  -> replicated updated params

JAX에서 먼저 익혀야 할 질문은 이것이다.

이 array는 replicated인가 sharded인가?
이 function은 어떤 axis name 위에서 collective를 쓰는가?
이 shape change가 recompilation을 유발하는가?

multi-device 실습은 sharding_demo.py 다음 단계에서 추가한다. 먼저 placement를 이해한 뒤 train step을 shard하는 것이 맞다.