Attention Tiling

cudaattentiontilingshared-memory

attention tiling은 GEMM tiling과 닮았지만 softmax 상태가 추가된다.

Q block
K tile
scores in SRAM
V tile
O update

기본 구조

for each Q block:
  initialize m, l, O
  for each K/V tile:
    compute score tile QK^T
    update online softmax state
    update O with V tile

score tile은 잠깐 필요하다. 전체 score matrix를 HBM에 저장하지 않는다.

GEMM tiling과 다른 점

GEMM은 partial sum만 누적하면 된다. attention은 softmax 때문에 max와 normalization 상태를 같이 관리해야 한다.

확인

  • Q block은 반복 바깥에 있고 K/V tile은 왜 안쪽에서 순회하는가?
  • score tile을 저장하지 않는다는 말은 어떤 memory에 저장하지 않는다는 뜻인가?
  • softmax 상태 때문에 attention tiling이 matmul tiling보다 복잡해지는 이유는 무엇인가?