JAX pmap Replicated Data Parallel

마지막 수정:

jaxpmapdistributeddata-parallelism

JAX의 pmap은 replicated data parallel을 이해하기 좋은 출발점이다.

개념은 PyTorch DDP와 가깝다.

params는 device마다 복제
batch는 device axis로 분할
각 device가 local grads 계산
grads를 평균
같은 update 적용

JAX에서는 이 평균을 collective로 명시한다.

grads = jax.lax.pmean(grads, axis_name="data")

PyTorch DDP는 gradient all-reduce를 wrapper가 주로 처리한다. JAX pmap에서는 mapped function 안에서 어떤 값을 device 간 평균낼지 코드로 드러난다.

확인

  • replicated data parallel에서 parameter와 batch는 각각 복제/분할 중 무엇인가?
  • jax.lax.pmean은 PyTorch DDP의 어떤 동작과 대응되는가?
  • collective가 코드에 드러나는 것이 학습 시스템 이해에 주는 장점은 무엇인가?