Scatter

distributedcollectivescatter

Scatter는 source rank가 가진 큰 tensor를 여러 조각으로 나눠 각 rank에 하나씩 보내는 collective operation이다.

Before Scatter

rank 0
123456
source
rank 1
00
rank 2
00

After Scatter

rank 0
12
shard 0
rank 1
34
shard 1
rank 2
56
shard 2
Scatter는 source rank의 큰 tensor를 여러 조각으로 나눠 각 rank에 하나씩 보낸다.

Before와 after

Before scatter
rank 0: [1, 2, 3, 4, 5, 6]
rank 1: [0, 0]
rank 2: [0, 0]

After scatter
rank 0: [1, 2]
rank 1: [3, 4]
rank 2: [5, 6]

Broadcast가 같은 값을 모두에게 복사한다면, scatter는 큰 값을 쪼개서 서로 다른 조각을 나눠 준다.

행렬곱에서 왜 필요한가

A[6 x 3] x B[3 x 2] = C[6 x 2]

A의 row를 rank별로 나눠 계산하고 싶다면, 먼저 A의 row shard를 각 rank에 보내야 한다. 각 rank는 자기 A shard와 같은 B를 곱해서 C의 row shard를 만든다.

Scatter A rows before local matmul

rank 0
A rows 0-1xB[3 x 2]=C rows 0-1
rank 1
A rows 2-3xB[3 x 2]=C rows 2-3
rank 2
A rows 4-5xB[3 x 2]=C rows 4-5
A를 row 방향으로 나눠 각 rank에 보내면, 각 rank는 자기 A shard와 같은 B를 곱해 C의 row shard를 계산한다.

언제 쓰는가

Scatter는 하나의 큰 tensor를 여러 rank에 나눠 배치하고 싶을 때 쓴다.

one full tensor -> different shard per rank

확인

  • Scatter 이후 각 rank는 같은 값을 갖는가, 서로 다른 조각을 갖는가?
  • Broadcast와 scatter의 가장 큰 차이는 무엇인가?
  • A의 row를 여러 rank에 나눠 행렬곱을 하려면 왜 scatter가 자연스러운가?