Scatter
Scatter는 source rank가 가진 큰 tensor를 여러 조각으로 나눠 각 rank에 하나씩 보내는 collective operation이다.
Before Scatter
rank 0
123456
source rank 1
00
rank 2
00
After Scatter
rank 0
12
shard 0 rank 1
34
shard 1 rank 2
56
shard 2 Before와 after
Before scatter
rank 0: [1, 2, 3, 4, 5, 6]
rank 1: [0, 0]
rank 2: [0, 0]
After scatter
rank 0: [1, 2]
rank 1: [3, 4]
rank 2: [5, 6]
Broadcast가 같은 값을 모두에게 복사한다면, scatter는 큰 값을 쪼개서 서로 다른 조각을 나눠 준다.
행렬곱에서 왜 필요한가
A[6 x 3] x B[3 x 2] = C[6 x 2]
A의 row를 rank별로 나눠 계산하고 싶다면, 먼저 A의 row shard를 각 rank에 보내야 한다. 각 rank는 자기 A shard와 같은 B를 곱해서 C의 row shard를 만든다.
Scatter A rows before local matmul
rank 0
A rows 0-1xB[3 x 2]=C rows 0-1
rank 1
A rows 2-3xB[3 x 2]=C rows 2-3
rank 2
A rows 4-5xB[3 x 2]=C rows 4-5
언제 쓰는가
Scatter는 하나의 큰 tensor를 여러 rank에 나눠 배치하고 싶을 때 쓴다.
one full tensor -> different shard per rank
확인
- Scatter 이후 각 rank는 같은 값을 갖는가, 서로 다른 조각을 갖는가?
- Broadcast와 scatter의 가장 큰 차이는 무엇인가?
A의 row를 여러 rank에 나눠 행렬곱을 하려면 왜 scatter가 자연스러운가?