Manual PyTorch Attention Benchmark
마지막 수정:
Attention benchmark의 첫 기준선은 직접 작성한 PyTorch attention이다.
구조는 명시적이다.
scores = Q @ K^T
scores = scores + causal_mask
probs = softmax(scores)
out = probs @ V
이 구현은 빠르기 위해 쓰는 것이 아니다. 어떤 intermediate tensor가 생기는지 보기 위해 쓴다.
Q, K, V: [B, H, T, Dh]
scores: [B, H, T, T]
probs: [B, H, T, T]
out: [B, H, T, Dh]
Profiler에서 확인할 것은 다음이다.
matmul time
softmax time
mask / fill overhead
peak memory
이 기준선이 있어야 PyTorch SDPA나 optimized attention이 무엇을 줄였는지 볼 수 있다.
확인
- manual attention에서 가장 큰 intermediate tensor는 무엇인가?
- causal mask를 materialize하면 memory와 time에 어떤 영향이 생기는가?
- 이 구현이 느려도 benchmark 기준선으로 중요한 이유는 무엇인가?