PyTorch Decoder-only Transformer 조립
마지막 수정:
이제 작은 decoder-only Transformer를 조립한다.
token ids
-> token embedding
-> position information
-> decoder blocks
-> final norm
-> lm head
-> logits
PyTorch에서는 block을 nn.ModuleList로 쌓을 수 있다.
self.blocks = nn.ModuleList([
DecoderBlock(config) for _ in range(n_layers)
])
forward에서는 block을 순서대로 통과한다.
for block in self.blocks:
x = block(x)
이 카드가 끝나면 아직 잘 학습된 모델은 아니지만, next-token logits를 내는 최소 모델이 생긴다.
확인
ModuleList를 쓰는 이유는 무엇인가?- final norm은 LM head 앞에서 어떤 역할을 하는가?
- 이 모델의 최종 출력 shape는 무엇인가?