PyTorch RMSNorm과 Residual Path
마지막 수정:
Transformer block은 큰 계산을 residual path 위에 올린다.
x -> norm -> sublayer -> + x
RMSNorm은 hidden dimension을 따라 root mean square를 계산하고 scale parameter를 곱한다.
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
return self.weight * x * inv_rms
이 구현은 나중에 CUDA RMSNorm kernel과 비교할 PyTorch reference가 된다.
확인
- RMSNorm은 어떤 dimension을 줄여서 통계값을 계산하는가?
- residual connection은 sublayer output을 어디에 더하는가?
- custom CUDA RMSNorm을 연결하기 전에 PyTorch reference가 필요한 이유는 무엇인가?