PyTorch Token Embedding과 LM Head

마지막 수정:

pytorchtransformerembeddinglogits

Decoder-only Transformer의 양끝은 token embedding과 LM head다.

token ids: [B, T]
embedding: [B, T, D]
lm head:   [B, T, V]

PyTorch에서는 보통 nn.Embedding으로 token id를 vector로 바꾼다.

self.token_embedding = nn.Embedding(vocab_size, d_model)

x = self.token_embedding(input_ids)

마지막 hidden state는 vocab 전체에 대한 logits로 바뀐다.

self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

logits = self.lm_head(x)

이 카드의 산출물은 아직 attention이 없는 작은 model shell이다.

input_ids -> embedding -> logits

확인

  • input_ids와 embedding output의 shape는 어떻게 다른가?
  • LM head는 hidden state를 무엇으로 바꾸는가?
  • logits는 확률인가, 확률이 되기 전 점수인가?