Custom RMSNorm Extension을 PyTorch Transformer에 연결하기
마지막 수정:
이 카드는 기존 CUDA 트랙과 PyTorch Transformer path가 만나는 지점이다.
먼저 PyTorch reference RMSNorm이 있다.
PyTorch RMSNorm
-> correctness reference
그다음 CUDA extension RMSNorm을 같은 interface로 감싼다.
CustomRMSNorm(nn.Module)
-> autograd.Function
-> C++ binding
-> CUDA kernel
마지막으로 Transformer block 안의 RMSNorm만 교체한다.
Block(norm=PyTorchRMSNorm)
Block(norm=CustomCUDARMSNorm)
비교는 correctness와 speed를 함께 본다.
확인
- custom kernel을 모델에 넣을 때 correctness를 먼저 확인해야 하는 이유는 무엇인가?
- RMSNorm만 빨라져도 전체 step time이 크게 줄지 않을 수 있는 이유는 무엇인가?
- 이 실험은 CUDA 트랙과 PyTorch 트랙을 어떻게 연결하는가?