JAX Sharding Mesh Basics

마지막 수정:

jaxdistributedshardingmesh

JAX distributed training은 먼저 array placement vocabulary를 익혀야 한다.

JAX Sharding Vocabulary Describe how arrays live on devices before scaling the train step.
Mesh -> NamedSharding -> PartitionSpec
data:0device 0batch rows 0..n
data:1device 1batch rows n..2n
data:2device 2batch rows 2n..3n
data:3device 3batch rows 3n..4n
Array [batch, seq]
+
PartitionSpec P("data", None)
->
Meaning shard batch, replicate seq
JAX distributed work starts by making array placement explicit instead of hiding it inside a wrapper.

PyTorch DDP에서는 보통 model wrapper가 replica와 gradient all-reduce를 숨겨준다.

model = DistributedDataParallel(model)

JAX에서는 array가 어떤 device mesh에 어떻게 놓이는지를 더 직접적으로 표현한다.

mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))
x = jax.device_put(x, sharding)

P("data", None)은 첫 번째 tensor dimension을 data axis로 shard하고, 두 번째 dimension은 복제/유지한다는 뜻이다.

실습 위치

python3 labs/jax-transformer/sharding_demo.py

CPU 1개만 있어도 vocabulary는 확인할 수 있다. GPU/TPU 여러 개가 있으면 addressable_shards 출력이 실제 shard placement를 보여준다.