ReduceScatter

distributedcollectivereduce-scatter

ReduceScatter는 여러 rank의 값을 먼저 합산하고, 합산된 결과를 다시 shard로 나눠 각 rank에 주는 collective operation이다.

Before ReduceScatter

rank 0
123456
rank 1
102030405060
rank 2
100200300400500600

After ReduceScatter

rank 0
111222
sum shard 0
rank 1
333444
sum shard 1
rank 2
555666
sum shard 2
ReduceScatter는 먼저 합산하고, 합산된 결과를 다시 shard로 나눠 각 rank에 준다.

Before와 after

Before reduce-scatter
rank 0: [1, 2, 3, 4, 5, 6]
rank 1: [10, 20, 30, 40, 50, 60]
rank 2: [100, 200, 300, 400, 500, 600]

Sum result
[111, 222, 333, 444, 555, 666]

After reduce-scatter
rank 0: [111, 222]
rank 1: [333, 444]
rank 2: [555, 666]

AllReduce라면 모든 rank가 [111, 222, 333, 444, 555, 666] 전체를 갖는다. ReduceScatter는 합산 결과를 full tensor로 복제하지 않고, 각 rank가 자기 shard만 갖게 한다.

행렬곱에서 왜 필요한가

inner dimension을 나눠 계산하면 partial C들을 더해야 한다. 그런데 다음 계산이 full C가 아니라 C의 shard만 필요로 한다면, 모든 rank에 full C를 복제할 필요가 없다.

이때 ReduceScatter는 두 단계를 한 번에 수행한다.

1. partial C들을 더한다.
2. 합산된 C를 shard로 나눠 각 rank에 둔다.

Reduce partial C and keep C sharded

rank 0
partial C0+partial C1+partial C2
after
C left|C middle|C right
ReduceScatter는 partial C들을 더하되 full C를 모두에게 복제하지 않는다. 합산된 C를 다시 shard로 나눠 각 rank가 필요한 조각만 갖게 한다.

언제 쓰는가

ReduceScatter는 partial result를 합쳐야 하지만, 합쳐진 결과를 계속 shard된 상태로 유지하고 싶을 때 쓴다.

partial full results on many ranks -> summed shard per rank

확인

  • ReduceScatter는 reduce와 scatter 중 어떤 일을 함께 하는가?
  • AllReduce와 비교했을 때 ReduceScatter가 full result 복제를 피하는 이유는 무엇인가?
  • 다음 계산이 C의 shard만 필요로 한다면 왜 ReduceScatter가 자연스러운가?