CUDA Attention과 FlashAttention
이 경로는 Path 3의 최적화 재료를 attention에 적용한다.
목표는 FlashAttention을 외워서 구현하는 것이 아니다. 먼저 naive attention이 왜 큰 score matrix를 만들고, 그 matrix가 왜 memory bottleneck이 되는지 본다. 그 다음 online softmax와 tiling을 합쳐 score matrix를 HBM에 저장하지 않는 방식으로 넘어간다.
끝까지 가면 FlashAttention forward의 핵심을 다음 한 문장으로 설명할 수 있어야 한다.
QK^T 전체를 저장하지 않고, K/V tile을 흘려보내며 row별 softmax 상태와 output을 갱신한다.
- Attention Score Matrix — QK^T가 sequence length 제곱 크기의 score matrix를 만들고 memory bottleneck을 일으키는 이유를 본다.
- Softmax와 Value Mixing — attention score row를 softmax로 바꾼 뒤 V row들을 가중합하는 과정을 CUDA 관점에서 읽는다.
- Online Softmax — 전체 score row를 한 번에 보관하지 않고 tile별 max와 sum 상태를 갱신하는 online softmax를 이해한다.
- Attention Tiling — Q block과 K/V tile로 attention을 나누어 SRAM 안에서 score와 softmax 상태를 처리하는 구조를 본다.
- FlashAttention Forward — FlashAttention forward가 QK^T, softmax, PV를 하나의 tiled kernel 흐름으로 합치는 방식을 이해한다.
- FlashAttention IO Analysis — FlashAttention의 핵심 이득이 FLOP 감소가 아니라 HBM read/write 감소라는 점을 이해한다.
- Causal Mask와 Tiled Attention — decoder-only attention에서 미래 token을 보지 않도록 mask를 tile 계산 안에 넣는 방식을 이해한다.
- PyTorch에서 Attention Kernel 비교하기 — naive attention, torch SDPA, FlashAttention 계열 kernel을 correctness와 benchmark 관점에서 비교한다.