Custom RMSNorm Extension을 PyTorch Transformer에 연결하기

마지막 수정:

pytorchcudaextensionrmsnormbenchmark

이 카드는 기존 CUDA 트랙과 PyTorch Transformer path가 만나는 지점이다.

먼저 PyTorch reference RMSNorm이 있다.

PyTorch RMSNorm
  -> correctness reference

그다음 CUDA extension RMSNorm을 같은 interface로 감싼다.

CustomRMSNorm(nn.Module)
  -> autograd.Function
  -> C++ binding
  -> CUDA kernel

마지막으로 Transformer block 안의 RMSNorm만 교체한다.

Block(norm=PyTorchRMSNorm)
Block(norm=CustomCUDARMSNorm)

비교는 correctness와 speed를 함께 본다.

확인

  • custom kernel을 모델에 넣을 때 correctness를 먼저 확인해야 하는 이유는 무엇인가?
  • RMSNorm만 빨라져도 전체 step time이 크게 줄지 않을 수 있는 이유는 무엇인가?
  • 이 실험은 CUDA 트랙과 PyTorch 트랙을 어떻게 연결하는가?