PyTorch Tensor와 nn.Module 기본기
마지막 수정:
이 path의 첫 목표는 PyTorch를 “모델을 불러오는 도구”가 아니라 tensor program을 작성하는 도구로 보는 것이다.
Transformer 구현에서 계속 확인할 것은 세 가지다.
shape
device
dtype
nn.Module은 parameter를 가진 계산 블록을 묶는다.
class Block(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(128, 128)
def forward(self, x):
return self.linear(x)
중요한 습관은 forward 안에서 입력과 출력 shape를 계속 추적하는 것이다.
x: [B, T, D]
out: [B, T, D]
확인
- PyTorch tensor를 볼 때 shape, device, dtype을 함께 확인해야 하는 이유는 무엇인가?
nn.Module은 단순 함수와 무엇이 다른가?- Transformer path에서 가장 자주 등장할 기본 shape는 무엇인가?