JAX pjit와 PartitionSpec

마지막 수정:

jaxpjitshardingpartitionspec

pmap이 replicated data parallel의 출발점이라면, pjit 계열 사고는 더 일반적인 sharding vocabulary다.

먼저 device mesh를 만든다.

mesh = Mesh(devices, axis_names=("data",))

그리고 array axis가 mesh axis에 어떻게 놓이는지 표현한다.

sharding = NamedSharding(mesh, PartitionSpec("data", None))

이 표현은 다음 질문에 답한다.

이 array의 어떤 축을 어떤 device mesh axis로 나눌 것인가?

PyTorch FSDP가 module parameter를 shard하는 API로 보인다면, JAX에서는 array layout 자체를 먼저 말하는 느낌에 가깝다.

실습

python3 labs/jax-transformer/sharding_demo.py

확인

  • PartitionSpec("data", None)에서 None은 무엇을 뜻하는가?
  • JAX sharding vocabulary가 PyTorch wrapper API보다 낮은 수준으로 느껴지는 이유는 무엇인가?
  • mesh axis 이름을 붙이는 것이 코드 읽기에 주는 장점은 무엇인가?