Next-token Loss와 Batch 만들기

마지막 수정:

pytorchtransformerlossdataset

Language model 학습은 다음 token을 맞히는 문제로 만들 수 있다.

tokens:  [t0, t1, t2, t3, t4]
input:   [t0, t1, t2, t3]
target:  [t1, t2, t3, t4]

모델은 input을 보고 각 위치에서 다음 token logits를 낸다.

logits: [B, T, V]
target: [B, T]

PyTorch의 cross_entropy는 보통 class dimension을 펼쳐서 계산한다.

loss = F.cross_entropy(
    logits.view(B * T, V),
    targets.view(B * T),
)

이 카드의 산출물은 model forward와 loss 계산이 연결된 최소 batch pipeline이다.

확인

  • next-token prediction에서 input과 target은 어떻게 만들어지는가?
  • logits와 target의 shape는 어떻게 다른가?
  • cross_entropy 전에 logits를 펼치는 이유는 무엇인가?