Naive GEMV Kernel
GEMV는 matrix-vector multiplication이다.
y[M] = W[M, K] x x[K]
thread 하나가 y[row] 하나를 계산할 수 있다.
y[row] = dot(W[row, :], x[:])
W [M x K]
x
x [K]
=
y [M]
왜 Transformer inference에서 중요한가
prefill에서는 여러 token을 한 번에 처리하므로 GEMM이 중심이다.
prefill:
[T, D] x [D, D] -> [T, D]
decode에서는 새 token 하나를 생성하므로 연산이 GEMV처럼 얇아진다.
decode:
[1, D] x [D, D] -> [1, D]
그래서 LLM inference에서는 GEMM뿐 아니라 GEMV의 memory-bound 성격도 중요하다.
Naive kernel
__global__ void gemv_naive_kernel(
const float* W,
const float* x,
float* y,
int M,
int K
) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M) {
float acc = 0.0f;
for (int col = 0; col < K; col++) {
acc += W[row * K + col] * x[col];
}
y[row] = acc;
}
}
GEMM과의 차이
GEMM:
2D output C[row, col]
2D grid가 자연스러움
GEMV:
1D output y[row]
1D grid가 자연스러움
GEMV는 W를 많이 읽지만 reuse가 작다. 그래서 보통 compute보다 memory bandwidth에 묶인다.
확인
- GEMV output은 왜 1D인가?
- prefill은 왜 GEMM에 가깝고 decode는 왜 GEMV에 가까운가?
- naive GEMV가 memory-bound가 되기 쉬운 이유는 무엇인가?