Attention Tiling
attention tiling은 GEMM tiling과 닮았지만 softmax 상태가 추가된다.
Q block
K tile
scores in SRAM
V tile
O update
기본 구조
for each Q block:
initialize m, l, O
for each K/V tile:
compute score tile QK^T
update online softmax state
update O with V tile
score tile은 잠깐 필요하다. 전체 score matrix를 HBM에 저장하지 않는다.
GEMM tiling과 다른 점
GEMM은 partial sum만 누적하면 된다. attention은 softmax 때문에 max와 normalization 상태를 같이 관리해야 한다.
확인
- Q block은 반복 바깥에 있고 K/V tile은 왜 안쪽에서 순회하는가?
- score tile을 저장하지 않는다는 말은 어떤 memory에 저장하지 않는다는 뜻인가?
- softmax 상태 때문에 attention tiling이 matmul tiling보다 복잡해지는 이유는 무엇인가?