JAX FSDP-style Parameter Sharding
마지막 수정:
FSDP와 ZeRO의 핵심은 “모든 device가 모든 state를 완전히 들고 있지 않게 한다”는 것이다.
나눌 수 있는 state는 세 가지다.
parameters
gradients
optimizer state
PyTorch FSDP는 module 단위 API로 이 일을 숨겨준다. JAX에서는 pytree leaf별 sharding rule과 optimizer state layout을 더 직접적으로 설계하는 경우가 많다.
params pytree
leaf -> sharding rule
grads pytree
same structure
optimizer state pytree
same or larger structure
이 관점이 중요한 이유는 optimizer state가 parameter보다 클 수 있기 때문이다. Adam 계열은 보통 m, v를 추가로 들고 있으므로, parameter만이 아니라 optimizer state sharding까지 같이 봐야 한다.
확인
- FSDP/ZeRO가 줄이려는 세 가지 state는 무엇인가?
- JAX parameter pytree와 optimizer state pytree가 같은 구조를 갖는 것이 왜 유용한가?
- parameter만 shard하고 optimizer state를 shard하지 않으면 어떤 memory 문제가 남는가?