JAX pmap Replicated Data Parallel
마지막 수정:
JAX의 pmap은 replicated data parallel을 이해하기 좋은 출발점이다.
개념은 PyTorch DDP와 가깝다.
params는 device마다 복제
batch는 device axis로 분할
각 device가 local grads 계산
grads를 평균
같은 update 적용
JAX에서는 이 평균을 collective로 명시한다.
grads = jax.lax.pmean(grads, axis_name="data")
PyTorch DDP는 gradient all-reduce를 wrapper가 주로 처리한다. JAX pmap에서는 mapped function 안에서 어떤 값을 device 간 평균낼지 코드로 드러난다.
확인
- replicated data parallel에서 parameter와 batch는 각각 복제/분할 중 무엇인가?
jax.lax.pmean은 PyTorch DDP의 어떤 동작과 대응되는가?- collective가 코드에 드러나는 것이 학습 시스템 이해에 주는 장점은 무엇인가?