JAX Single-Device Training Loop
마지막 수정:
JAX 단일 디바이스 학습 루프의 기준선은 다음이다.
split key
-> make batch
-> jit train_step
-> params / opt_state update
-> loss.block_until_ready()
train.py는 train_step을 jax.jit로 감싼다.
train_step = jax.jit(train_step)
JAX dispatch는 비동기다. 그래서 timing을 볼 때는 결과가 실제로 끝났는지 기다려야 한다.
loss.block_until_ready()
이 기준선에서 먼저 확인할 것은 두 가지다.
loss가 내려가는가?
compile 이후 steady-state tokens/sec가 얼마인가?
실습
python3 labs/jax-transformer/train.py --steps 100 --log-every 10
확인
- JAX timing에서
block_until_ready()가 필요한 이유는 무엇인가? - 첫 step 시간이 steady-state step 시간과 다를 수 있는 이유는 무엇인가?
- 단일 디바이스 기준선을 잡지 않고 분산으로 넘어가면 어떤 문제가 생기는가?