JAX Transformer 학습하기
이 경로는 JAX로 작성한 Transformer를 실제 학습 루프로 연결합니다.
PyTorch의 loss.backward()와 optimizer.step()은 내부 state mutation을 자연스럽게 사용합니다. JAX에서는 train step이 더 함수형입니다.
params, opt_state, batch
-> loss, grads
-> new_params, new_opt_state
핵심은 update되는 모든 것을 return value로 명시하는 것입니다. 이 형태가 되어야 jax.jit, jax.value_and_grad, sharding과 잘 맞습니다.
이 path의 산출물은 loss가 내려가는 단일 디바이스 기준선입니다. 이 기준선이 있어야 나중에 sharding, pmap, pjit, multi-host 학습이 실제로 무엇을 바꿨는지 비교할 수 있습니다.
- JAX Transformer Forward Pass — RMSNorm, causal attention, MLP, decoder block, tied LM head를 JAX 함수 합성으로 구현하는 방식을 이해한다.
- JAX Next-Token Loss와 Batch — JAX에서 next-token prediction batch를 만들고 logits와 target으로 cross entropy loss를 계산한다.
- JAX Single-Device Training Step Review — 단일 디바이스 JAX training step에서 params, optimizer state, PRNG key, metrics가 어떻게 흐르는지 정리한다.
- JAX value_and_grad Training Step — PyTorch의 backward/optimizer.step 흐름을 JAX의 value_and_grad, explicit AdamW state, pure train_step으로 바꾼다.
- JAX jit Compile Boundary — JAX에서 첫 호출 compile time과 이후 execution time을 분리하고, static shape와 block_until_ready가 왜 중요한지 이해한다.
- JAX Single-Device Training Loop — jit-compiled train_step을 반복 호출하고 block_until_ready로 실제 throughput을 측정한다.
- JAX Gradient Accumulation — JAX에서 microbatch gradient를 누적해 effective batch size를 키우는 방법을 함수형 상태 흐름으로 이해한다.
- JAX Mixed Precision과 Memory — JAX에서 dtype을 명시적으로 다루며 activation, parameter, optimizer state memory를 구분한다.