JAX pjit와 PartitionSpec
마지막 수정:
pmap이 replicated data parallel의 출발점이라면, pjit 계열 사고는 더 일반적인 sharding vocabulary다.
먼저 device mesh를 만든다.
mesh = Mesh(devices, axis_names=("data",))
그리고 array axis가 mesh axis에 어떻게 놓이는지 표현한다.
sharding = NamedSharding(mesh, PartitionSpec("data", None))
이 표현은 다음 질문에 답한다.
이 array의 어떤 축을 어떤 device mesh axis로 나눌 것인가?
PyTorch FSDP가 module parameter를 shard하는 API로 보인다면, JAX에서는 array layout 자체를 먼저 말하는 느낌에 가깝다.
실습
python3 labs/jax-transformer/sharding_demo.py
확인
PartitionSpec("data", None)에서None은 무엇을 뜻하는가?- JAX sharding vocabulary가 PyTorch wrapper API보다 낮은 수준으로 느껴지는 이유는 무엇인가?
- mesh axis 이름을 붙이는 것이 코드 읽기에 주는 장점은 무엇인가?