PyTorch RMSNorm과 Residual Path

마지막 수정:

pytorchtransformerrmsnormresidual

Transformer block은 큰 계산을 residual path 위에 올린다.

x -> norm -> sublayer -> + x

RMSNorm은 hidden dimension을 따라 root mean square를 계산하고 scale parameter를 곱한다.

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return self.weight * x * inv_rms

이 구현은 나중에 CUDA RMSNorm kernel과 비교할 PyTorch reference가 된다.

확인

  • RMSNorm은 어떤 dimension을 줄여서 통계값을 계산하는가?
  • residual connection은 sublayer output을 어디에 더하는가?
  • custom CUDA RMSNorm을 연결하기 전에 PyTorch reference가 필요한 이유는 무엇인가?