JAX Sharding Mesh Basics
마지막 수정:
JAX distributed training은 먼저 array placement vocabulary를 익혀야 한다.
JAX Sharding Vocabulary Describe how arrays live on devices before scaling the train step.
Mesh -> NamedSharding -> PartitionSpec data:0device 0
batch rows 0..ndata:1device 1
batch rows n..2ndata:2device 2
batch rows 2n..3ndata:3device 3
batch rows 3n..4n Array
+ [batch, seq] PartitionSpec
-> P("data", None) Meaning
shard batch, replicate seq PyTorch DDP에서는 보통 model wrapper가 replica와 gradient all-reduce를 숨겨준다.
model = DistributedDataParallel(model)
JAX에서는 array가 어떤 device mesh에 어떻게 놓이는지를 더 직접적으로 표현한다.
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))
x = jax.device_put(x, sharding)
P("data", None)은 첫 번째 tensor dimension을 data axis로 shard하고, 두 번째 dimension은 복제/유지한다는 뜻이다.
실습 위치
python3 labs/jax-transformer/sharding_demo.py
CPU 1개만 있어도 vocabulary는 확인할 수 있다. GPU/TPU 여러 개가 있으면 addressable_shards 출력이 실제 shard placement를 보여준다.