Reduction Optimization

cudareductionsoftmaxrmsnorm

reduction은 여러 값을 하나의 값으로 줄인다. softmax의 max/sum, RMSNorm의 sum of squares가 모두 reduction이다.

25143607
many inputs -> one statistic
sum28
max7
mean3.5
Reduction은 여러 원소를 읽어 하나의 통계량을 만든다.

naive reduction의 한계

thread 하나가 row 전체를 순회하면 병렬성이 부족하다.

float sum = 0.0f;
for (int i = 0; i < cols; i++) {
    sum += x[row * cols + i];
}

최적화된 kernel에서는 여러 thread가 row 일부를 읽고, partial sum을 만든 뒤 block 안에서 줄인다.

block reduction 스케치

partial[threadIdx.x] = local_sum;
__syncthreads();

for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
        partial[threadIdx.x] += partial[threadIdx.x + stride];
    }
    __syncthreads();
}

확인

  • reduction은 output 원소 하나가 input 여러 개에 의존하는 패턴이다.
  • softmax는 max reduction과 sum reduction이 모두 필요하다.
  • shared memory reduction에는 synchronization 비용이 있다.