JAX value_and_grad Training Step
마지막 수정:
JAX 학습 루프의 중심은 loss.backward()가 아니라 jax.value_and_grad(loss_fn)이다.
JAX Model Shape Parameters are data. Forward and train steps are pure functions.
params, batch -> loss, grads parameter pytree
+ {"blocks": [...], "token_embedding": ...} pure forward
-> forward(params, input_ids, config) logits
[batch, seq, vocab] PyTorch module owns parameters and behavior
model(input_ids) JAX state is explicit and transformed by functions
jit(value_and_grad(loss_fn)) PyTorch의 step은 mutable object를 갱신한다.
loss.backward()
optimizer.step()
JAX lab의 step은 state를 입력으로 받고 새 state를 반환한다.
loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets, config)
params, opt_state = adamw_update(params, grads, opt_state, lr)
return params, opt_state, loss
이 방식은 더 장황하지만 중요한 장점이 있다.
train_step이 pure function에 가까움
optimizer state가 숨지 않음
jit/sharding transformation이 쉬움
checkpoint할 state boundary가 명확함
JAX에서 “학습 상태”는 보통 다음 묶음이다.
params
optimizer state
PRNG key
step counter
실습 위치
labs/jax-transformer/train.py
train_step 안에서 loss, grad, AdamW update가 어떤 순서로 pure function처럼 연결되는지 확인한다.