Transformer Kernel Map

cudatransformerkernel-map

지금까지 구현한 naive kernel들은 따로 떨어진 장난감이 아니다. Transformer의 실제 연산으로 이어진다.

GEMMQ/K/V projection
MLP
GEMVdecode projection
Softmaxattention weights
RMSNormblock normalization
Reductionsoftmax / norm / sampling
Transposelayout transform
Path 2의 naive kernels는 Transformer block 안의 실제 연산으로 이어진다.

Kernel 위치

KernelTransformer에서의 위치
Transposetensor layout 변환, Q/K shape 정렬
GEMMQ/K/V projection, MLP, output projection
GEMVdecode step의 single-token projection
Softmaxattention weights
Reductionsoftmax, RMSNorm, sampling
RMSNormtransformer block normalization

Prefill과 Decode

prefill은 prompt 전체를 한 번에 처리한다.

many tokens -> GEMM 중심

decode는 새 token 하나씩 생성한다.

one new token -> GEMV / KV cache read 중심

이 구분은 나중에 vLLM과 PagedAttention을 이해할 때 중요하다.

왜 naive kernel을 먼저 하나

최적화 kernel은 처음부터 읽기 어렵다. naive kernel은 느리지만 수식과 좌표가 직접 보인다.

수식
-> CPU reference
-> naive CUDA kernel
-> PyTorch 연결
-> profile
-> optimization

이 순서가 유지되어야 성능 개선도 신뢰할 수 있다.

확인

  • Q/K/V projection은 GEMM인가 GEMV인가?
  • decode에서 GEMV가 중요해지는 이유는 무엇인가?
  • softmax와 RMSNorm이 공통으로 필요로 하는 kernel pattern은 무엇인가?