JAX Single-Device Training Step Review

마지막 수정:

jaxtrainingsingle-deviceoptimizer

JAX training step은 상태를 내부에서 바꾸지 않는다.

(params, opt_state, batch)
  -> loss / grads
  -> new_params / new_opt_state / metrics

PyTorch에서는 다음 흐름이 익숙하다.

loss.backward()
optimizer.step()
optimizer.zero_grad()

JAX에서는 같은 일을 반환값으로 표현한다.

loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets, config)
params, opt_state = adamw_update(params, grads, opt_state, lr)

이 차이는 취향 문제가 아니다. JAX가 function을 trace하고 compile하고 shard하려면 어떤 값이 들어오고 나가는지 명확해야 한다.

확인

  • JAX training step에서 update되는 값은 무엇을 return해야 하는가?
  • optimizer.step() 같은 hidden mutation이 줄어들면 compile/sharding에 어떤 장점이 있는가?
  • metrics는 왜 Python print가 아니라 step의 반환값으로 빼는 편이 좋은가?