JAX FSDP-style Parameter Sharding

마지막 수정:

jaxfsdpzeroparameter-sharding

FSDP와 ZeRO의 핵심은 “모든 device가 모든 state를 완전히 들고 있지 않게 한다”는 것이다.

나눌 수 있는 state는 세 가지다.

parameters
gradients
optimizer state

PyTorch FSDP는 module 단위 API로 이 일을 숨겨준다. JAX에서는 pytree leaf별 sharding rule과 optimizer state layout을 더 직접적으로 설계하는 경우가 많다.

params pytree
  leaf -> sharding rule
grads pytree
  same structure
optimizer state pytree
  same or larger structure

이 관점이 중요한 이유는 optimizer state가 parameter보다 클 수 있기 때문이다. Adam 계열은 보통 m, v를 추가로 들고 있으므로, parameter만이 아니라 optimizer state sharding까지 같이 봐야 한다.

확인

  • FSDP/ZeRO가 줄이려는 세 가지 state는 무엇인가?
  • JAX parameter pytree와 optimizer state pytree가 같은 구조를 갖는 것이 왜 유용한가?
  • parameter만 shard하고 optimizer state를 shard하지 않으면 어떤 memory 문제가 남는가?