행렬 곱셈

linear-algebramatmulperformance

행렬 곱셈은 왼쪽 행렬의 행(row) 과 오른쪽 행렬의 열(column) 을 하나씩 내적해서 출력 행렬의 한 칸을 만든다.

ARm×k,BRk×n,C=ABRm×nA \in \mathbb{R}^{m \times k}, \quad B \in \mathbb{R}^{k \times n}, \quad C = AB \in \mathbb{R}^{m \times n}

출력 원소 하나는 이렇게 계산된다.

Cij=t=1kAitBtjC_{ij} = \sum_{t=1}^{k} A_{it}B_{tj}
A [3 x 2]
x
B [2 x 3]
=
C [3 x 3]

연산량

출력 행렬 C에는 m x n개의 원소가 있다. 각 원소를 만들 때 길이 k짜리 내적을 계산한다.

내적 하나에는 보통 k번의 곱셈과 k - 1번의 덧셈이 필요하다. 성능 분석에서는 multiply-add를 2 FLOPs로 세는 경우가 많으므로, 전체 연산량은 대략 다음과 같다.

FLOPs2mnk\text{FLOPs} \approx 2mnk

정확히 덧셈을 k - 1번으로 세면:

ops=mn(2k1)\text{ops} = mn(2k - 1)

대규모 행렬에서는 2mnk로 생각해도 충분하다.

메모리 이동량

가장 낮은 하한은 입력 행렬 A, B를 한 번씩 읽고 출력 행렬 C를 한 번 쓰는 것이다.

원소 개수 기준:

elements movedmk+kn+mn\text{elements moved} \ge mk + kn + mn

원소 하나가 s bytes라면:

bytes moveds(mk+kn+mn)\text{bytes moved} \ge s(mk + kn + mn)

예를 들어 FP16/BF16이면 s = 2, FP32이면 s = 4다.

이 하한은 각 원소를 메모리에서 한 번만 가져올 수 있다는 이상적인 가정이다. 실제 naive 구현은 C의 원소 하나를 계산할 때마다 필요한 A의 행과 B의 열을 다시 읽을 수 있다.

C에는 mn개의 원소가 있고, 각 원소는 길이 k짜리 내적이다. 내적 하나를 계산하려면 A에서 k개, B에서 k개를 읽는다. 즉 내적 하나당 입력 원소를 2k개 읽는다.

naive read elements2mnk\text{naive read elements} \approx 2mnk

bytes 기준으로는 원소 크기를 곱하면 된다.

naive read bytes2smnk\text{naive read bytes} \approx 2smnk

BF16/FP16이면 s = 2이므로:

naive read bytes4mnk\text{naive read bytes} \approx 4mnk

여기에 출력 C를 쓰는 비용 smn이 추가된다. 그래서 빠른 행렬 곱셈 커널의 핵심은 같은 값을 여러 번 메모리에서 가져오지 않도록 타일링(tiling) 해서 재사용하는 것이다. 이 내용은 block-matrix-multiplication에서 따로 다룬다. 이 카드가 arithmetic-intensityroofline으로 이어지는 이유도 여기에 있다.

확인

  • A[m,k]B[k,n]을 곱하면 출력 C의 shape은 무엇인가?
  • 행렬 곱셈의 FLOPs를 왜 대략 2mnk로 세는가?
  • mk + kn + mn은 어떤 가정에서 나온 메모리 이동량인가?
  • naive 구현의 메모리 읽기량이 왜 2mnk에 가까워질 수 있는가?