PyTorch Decoder-only Transformer 조립

마지막 수정:

pytorchtransformerdecoder-onlylanguage-model

이제 작은 decoder-only Transformer를 조립한다.

token ids
  -> token embedding
  -> position information
  -> decoder blocks
  -> final norm
  -> lm head
  -> logits

PyTorch에서는 block을 nn.ModuleList로 쌓을 수 있다.

self.blocks = nn.ModuleList([
    DecoderBlock(config) for _ in range(n_layers)
])

forward에서는 block을 순서대로 통과한다.

for block in self.blocks:
    x = block(x)

이 카드가 끝나면 아직 잘 학습된 모델은 아니지만, next-token logits를 내는 최소 모델이 생긴다.

확인

  • ModuleList를 쓰는 이유는 무엇인가?
  • final norm은 LM head 앞에서 어떤 역할을 하는가?
  • 이 모델의 최종 출력 shape는 무엇인가?