JAX Multi-Host Scaling Map
마지막 수정:
Multi-host JAX 학습은 단일 프로세스 코드에 device만 더 붙이는 문제가 아니다.
먼저 구분해야 할 축이 있다.
local devices
global devices
process index
host-to-host network
mesh axis
단일 머신에서는 jax.devices()만 봐도 충분할 수 있다. multi-host에서는 전체 job의 global device mesh를 구성하고, 각 host가 자신이 담당하는 addressable shard를 가진다.
학습 시스템 관점에서 중요한 질문은 PyTorch와 같다.
무엇을 복제하는가?
무엇을 shard하는가?
어떤 collective가 critical path에 있는가?
network topology가 어느 축의 통신을 느리게 만드는가?
JAX의 장점은 이 질문을 array layout과 mesh axis 이름으로 코드에 남길 수 있다는 점이다.
확인
- local devices와 global devices를 구분해야 하는 이유는 무엇인가?
- multi-host에서 addressable shard라는 개념이 필요한 이유는 무엇인가?
- mesh axis와 network topology를 함께 봐야 하는 이유는 무엇인가?