Reduction Optimization
reduction은 여러 값을 하나의 값으로 줄인다. softmax의 max/sum, RMSNorm의 sum of squares가 모두 reduction이다.
25143607
many inputs -> one statistic
sum28
max7
mean3.5
naive reduction의 한계
thread 하나가 row 전체를 순회하면 병렬성이 부족하다.
float sum = 0.0f;
for (int i = 0; i < cols; i++) {
sum += x[row * cols + i];
}
최적화된 kernel에서는 여러 thread가 row 일부를 읽고, partial sum을 만든 뒤 block 안에서 줄인다.
block reduction 스케치
partial[threadIdx.x] = local_sum;
__syncthreads();
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) {
partial[threadIdx.x] += partial[threadIdx.x + stride];
}
__syncthreads();
}
확인
- reduction은 output 원소 하나가 input 여러 개에 의존하는 패턴이다.
- softmax는 max reduction과 sum reduction이 모두 필요하다.
- shared memory reduction에는 synchronization 비용이 있다.