Tiled Matrix Multiplication
tiled matmul은 Path 3의 중심 실습이다.
A tile
B tile
C block
naive GEMM의 문제
naive GEMM은 thread 하나가 C[row, col] 하나를 맡는다. 이해하기는 쉽지만, 같은 A row와 B column 조각을 여러 thread가 global memory에서 반복해서 읽는다.
tiled GEMM의 구조
for (int tile = 0; tile < K; tile += TILE) {
load A tile into shared memory;
load B tile into shared memory;
__syncthreads();
for (int kk = 0; kk < TILE; kk++) {
acc += As[ty][kk] * Bs[kk][tx];
}
__syncthreads();
}
핵심은 K 방향을 tile로 쪼개고, 각 tile을 shared memory에 올린 뒤 block 안 thread들이 재사용하는 것이다.
확인
- 왜 tile loop는 K 방향으로 도는가?
As[ty][kk]와Bs[kk][tx]는 어떤 output 원소에 기여하는가?- tile 크기를 키우면 reuse는 늘지만 어떤 비용도 같이 늘어나는가?