JAX value_and_grad Training Step

마지막 수정:

jaxtrainingvalue-and-gradoptimizer

JAX 학습 루프의 중심은 loss.backward()가 아니라 jax.value_and_grad(loss_fn)이다.

JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads
parameter pytree {"blocks": [...], "token_embedding": ...}
+
pure forward forward(params, input_ids, config)
->
logits [batch, seq, vocab]
PyTorch module owns parameters and behavior model(input_ids)
JAX state is explicit and transformed by functions jit(value_and_grad(loss_fn))
JAX code becomes easier to compile and shard when mutable training state is made explicit.

PyTorch의 step은 mutable object를 갱신한다.

loss.backward()
optimizer.step()

JAX lab의 step은 state를 입력으로 받고 새 state를 반환한다.

loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets, config)
params, opt_state = adamw_update(params, grads, opt_state, lr)
return params, opt_state, loss

이 방식은 더 장황하지만 중요한 장점이 있다.

train_step이 pure function에 가까움
optimizer state가 숨지 않음
jit/sharding transformation이 쉬움
checkpoint할 state boundary가 명확함

JAX에서 “학습 상태”는 보통 다음 묶음이다.

params
optimizer state
PRNG key
step counter

실습 위치

labs/jax-transformer/train.py

train_step 안에서 loss, grad, AdamW update가 어떤 순서로 pure function처럼 연결되는지 확인한다.