Distributed Matmul: Sharded Matrices and Collective Operations
이 경로는 분산 행렬곱을 이해하기 위한 경로입니다.
먼저 행렬을 여러 GPU/TPU에 나누면 왜 통신이 필요한지 직관적으로 봅니다. 그다음 개별 collective operation, sharding notation, device mesh, sharded matmul case를 차례로 쌓아갈 예정입니다.
- 왜 Collective Operation이 필요한가 — 행렬곱을 여러 GPU에 나누면 계산 전후에 데이터를 모으거나 합쳐야 하는 이유를 숫자 예시로 이해한다.
- Device Mesh and Tensor Sharding — 여러 GPU를 mesh로 이름 붙이고, tensor의 차원을 mesh 축에 나눠 local shard를 만드는 방법을 이해한다.
- Broadcast — 하나의 source rank가 가진 tensor를 모든 rank에 복사하는 collective operation.
- Scatter — 하나의 source tensor를 여러 shard로 나눠 각 rank에 하나씩 보내는 collective operation.
- Gather — 여러 rank가 가진 shard를 하나의 destination rank에 모으는 collective operation.
- AllGather — 여러 rank의 shard를 모아 모든 rank가 full tensor를 갖게 하는 collective operation.
- Reduce — 여러 rank의 값을 합산해 하나의 destination rank에만 결과를 만드는 collective operation.
- AllReduce — 여러 rank의 값을 합산한 뒤 모든 rank가 같은 합산 결과를 갖게 하는 collective operation.
- ReduceScatter — 여러 rank의 값을 합산하고, 합산 결과를 shard로 나눠 각 rank에 주는 collective operation.
- AllToAll — 각 rank가 가진 조각을 목적지별로 서로 교환해 sharding layout을 바꾸는 collective operation.