JAX Single-Device Training Loop

마지막 수정:

jaxtrainingsingle-devicejit

JAX 단일 디바이스 학습 루프의 기준선은 다음이다.

split key
  -> make batch
  -> jit train_step
  -> params / opt_state update
  -> loss.block_until_ready()

train.pytrain_stepjax.jit로 감싼다.

train_step = jax.jit(train_step)

JAX dispatch는 비동기다. 그래서 timing을 볼 때는 결과가 실제로 끝났는지 기다려야 한다.

loss.block_until_ready()

이 기준선에서 먼저 확인할 것은 두 가지다.

loss가 내려가는가?
compile 이후 steady-state tokens/sec가 얼마인가?

실습

python3 labs/jax-transformer/train.py --steps 100 --log-every 10

확인

  • JAX timing에서 block_until_ready()가 필요한 이유는 무엇인가?
  • 첫 step 시간이 steady-state step 시간과 다를 수 있는 이유는 무엇인가?
  • 단일 디바이스 기준선을 잡지 않고 분산으로 넘어가면 어떤 문제가 생기는가?