행렬 곱셈
행렬 곱셈은 왼쪽 행렬의 행(row) 과 오른쪽 행렬의 열(column) 을 하나씩 내적해서 출력 행렬의 한 칸을 만든다.
출력 원소 하나는 이렇게 계산된다.
연산량
출력 행렬 C에는 m x n개의 원소가 있다. 각 원소를 만들 때 길이 k짜리 내적을 계산한다.
내적 하나에는 보통 k번의 곱셈과 k - 1번의 덧셈이 필요하다. 성능 분석에서는 multiply-add를 2 FLOPs로 세는 경우가 많으므로, 전체 연산량은 대략 다음과 같다.
정확히 덧셈을 k - 1번으로 세면:
대규모 행렬에서는 2mnk로 생각해도 충분하다.
메모리 이동량
가장 낮은 하한은 입력 행렬 A, B를 한 번씩 읽고 출력 행렬 C를 한 번 쓰는 것이다.
원소 개수 기준:
원소 하나가 s bytes라면:
예를 들어 FP16/BF16이면 s = 2, FP32이면 s = 4다.
이 하한은 각 원소를 메모리에서 한 번만 가져올 수 있다는 이상적인 가정이다. 실제 naive 구현은 C의 원소 하나를 계산할 때마다 필요한 A의 행과 B의 열을 다시 읽을 수 있다.
C에는 mn개의 원소가 있고, 각 원소는 길이 k짜리 내적이다. 내적 하나를 계산하려면 A에서 k개, B에서 k개를 읽는다. 즉 내적 하나당 입력 원소를 2k개 읽는다.
bytes 기준으로는 원소 크기를 곱하면 된다.
BF16/FP16이면 s = 2이므로:
여기에 출력 C를 쓰는 비용 smn이 추가된다. 그래서 빠른 행렬 곱셈 커널의 핵심은 같은 값을 여러 번 메모리에서 가져오지 않도록 타일링(tiling) 해서 재사용하는 것이다. 이 내용은 block-matrix-multiplication에서 따로 다룬다. 이 카드가 arithmetic-intensity와 roofline으로 이어지는 이유도 여기에 있다.
확인
A[m,k]와B[k,n]을 곱하면 출력C의 shape은 무엇인가?- 행렬 곱셈의 FLOPs를 왜 대략
2mnk로 세는가? mk + kn + mn은 어떤 가정에서 나온 메모리 이동량인가?- naive 구현의 메모리 읽기량이 왜
2mnk에 가까워질 수 있는가?