Transformer Kernel Map
지금까지 구현한 naive kernel들은 따로 떨어진 장난감이 아니다. Transformer의 실제 연산으로 이어진다.
GEMMQ/K/V projection
MLP
MLP
GEMVdecode projection
Softmaxattention weights
RMSNormblock normalization
Reductionsoftmax / norm / sampling
Transposelayout transform
Kernel 위치
| Kernel | Transformer에서의 위치 |
|---|---|
| Transpose | tensor layout 변환, Q/K shape 정렬 |
| GEMM | Q/K/V projection, MLP, output projection |
| GEMV | decode step의 single-token projection |
| Softmax | attention weights |
| Reduction | softmax, RMSNorm, sampling |
| RMSNorm | transformer block normalization |
Prefill과 Decode
prefill은 prompt 전체를 한 번에 처리한다.
many tokens -> GEMM 중심
decode는 새 token 하나씩 생성한다.
one new token -> GEMV / KV cache read 중심
이 구분은 나중에 vLLM과 PagedAttention을 이해할 때 중요하다.
왜 naive kernel을 먼저 하나
최적화 kernel은 처음부터 읽기 어렵다. naive kernel은 느리지만 수식과 좌표가 직접 보인다.
수식
-> CPU reference
-> naive CUDA kernel
-> PyTorch 연결
-> profile
-> optimization
이 순서가 유지되어야 성능 개선도 신뢰할 수 있다.
확인
- Q/K/V projection은 GEMM인가 GEMV인가?
- decode에서 GEMV가 중요해지는 이유는 무엇인가?
- softmax와 RMSNorm이 공통으로 필요로 하는 kernel pattern은 무엇인가?