JAX Mixed Precision과 Memory
마지막 수정:
Mixed precision의 목적은 단순히 “더 빠르게”가 아니다. 먼저 어떤 memory가 줄어드는지 구분해야 한다.
parameters
gradients
optimizer state
activations
temporary buffers
JAX에서는 dtype이 array와 computation에 직접 드러난다.
x = x.astype(jnp.bfloat16)
하지만 모든 값을 무조건 낮은 precision으로 바꾸면 안 된다. optimizer state나 loss 계산은 더 높은 precision이 필요할 수 있다. PyTorch의 autocast가 많은 결정을 숨겨준다면, JAX에서는 어디서 cast할지 더 명시적으로 설계하는 편이다.
Frontier-scale training 관점에서는 이 명시성이 장점이 된다. sharding rule, dtype policy, optimizer state layout을 함께 설계할 수 있기 때문이다.
확인
- mixed precision이 줄일 수 있는 memory 항목은 무엇인가?
- optimizer state까지 낮은 precision으로 둘 때 생길 수 있는 위험은 무엇인가?
- JAX의 명시적 dtype 관리가 대규모 학습 시스템에서 장점이 되는 이유는 무엇인가?