PyTorch Tensor와 nn.Module 기본기

마지막 수정:

pytorchtransformertensormodule

이 path의 첫 목표는 PyTorch를 “모델을 불러오는 도구”가 아니라 tensor program을 작성하는 도구로 보는 것이다.

Transformer 구현에서 계속 확인할 것은 세 가지다.

shape
device
dtype

nn.Module은 parameter를 가진 계산 블록을 묶는다.

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)

중요한 습관은 forward 안에서 입력과 출력 shape를 계속 추적하는 것이다.

x: [B, T, D]
out: [B, T, D]

확인

  • PyTorch tensor를 볼 때 shape, device, dtype을 함께 확인해야 하는 이유는 무엇인가?
  • nn.Module은 단순 함수와 무엇이 다른가?
  • Transformer path에서 가장 자주 등장할 기본 shape는 무엇인가?