JAX Next-Token Loss와 Batch

마지막 수정:

jaxtransformerlossdataset

학습 루프의 첫 기준선은 next-token prediction이다.

input_ids: [B, T]
targets:   [B, T]
logits:    [B, T, vocab]

JAX lab은 외부 dataset 없이 deterministic modular sequence를 만든다. 목적은 데이터 품질이 아니라 학습 루프가 실제로 parameter를 update하는지 확인하는 것이다.

inputs, targets = make_modular_batch(key, batch_size, seq_len, vocab_size)

loss는 log-softmax 뒤 target 위치를 gather해서 평균낸다.

log_probs = jax.nn.log_softmax(logits, axis=-1)
token_losses = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1).squeeze(-1)
loss = jnp.mean(token_losses)

PyTorch와 달리 batch 생성도 random key를 명시적으로 받는다. 같은 key를 넣으면 같은 batch가 나온다. 새 batch가 필요하면 key를 split해야 한다.

확인

  • next-token loss에서 logits의 마지막 축은 무엇인가?
  • JAX batch 생성 함수가 PRNG key를 인자로 받는 이유는 무엇인가?
  • deterministic toy data가 학습 루프 검증에 유용한 이유는 무엇인가?