JAX Data Parallel Training Map
마지막 수정:
JAX data parallelism도 본질은 PyTorch DDP와 같다.
JAX Sharding Vocabulary Describe how arrays live on devices before scaling the train step.
Mesh -> NamedSharding -> PartitionSpec data:0device 0
batch rows 0..ndata:1device 1
batch rows n..2ndata:2device 2
batch rows 2n..3ndata:3device 3
batch rows 3n..4n Array
+ [batch, seq] PartitionSpec
-> P("data", None) Meaning
shard batch, replicate seq 각 device가 batch shard를 처리한다
params는 replica로 유지한다
gradient를 data axis에 대해 평균낸다
같은 update를 모든 replica에 적용한다
차이는 표현 방식이다. PyTorch는 DDP wrapper가 이 과정을 담당한다. JAX에서는 pmap, pjit, shard_map 같은 transformation과 sharding rule로 표현한다.
학습 관점의 mental model은 다음이다.
global batch [B, T]
-> shard batch axis across devices
-> local loss/grad
-> psum/pmean over data axis
-> replicated updated params
JAX에서 먼저 익혀야 할 질문은 이것이다.
이 array는 replicated인가 sharded인가?
이 function은 어떤 axis name 위에서 collective를 쓰는가?
이 shape change가 recompilation을 유발하는가?
multi-device 실습은 sharding_demo.py 다음 단계에서 추가한다. 먼저 placement를 이해한 뒤 train step을 shard하는 것이 맞다.