ReduceScatter
ReduceScatter는 여러 rank의 값을 먼저 합산하고, 합산된 결과를 다시 shard로 나눠 각 rank에 주는 collective operation이다.
Before ReduceScatter
rank 0
123456
rank 1
102030405060
rank 2
100200300400500600
After ReduceScatter
rank 0
111222
sum shard 0 rank 1
333444
sum shard 1 rank 2
555666
sum shard 2 Before와 after
Before reduce-scatter
rank 0: [1, 2, 3, 4, 5, 6]
rank 1: [10, 20, 30, 40, 50, 60]
rank 2: [100, 200, 300, 400, 500, 600]
Sum result
[111, 222, 333, 444, 555, 666]
After reduce-scatter
rank 0: [111, 222]
rank 1: [333, 444]
rank 2: [555, 666]
AllReduce라면 모든 rank가 [111, 222, 333, 444, 555, 666] 전체를 갖는다. ReduceScatter는 합산 결과를 full tensor로 복제하지 않고, 각 rank가 자기 shard만 갖게 한다.
행렬곱에서 왜 필요한가
inner dimension을 나눠 계산하면 partial C들을 더해야 한다. 그런데 다음 계산이 full C가 아니라 C의 shard만 필요로 한다면, 모든 rank에 full C를 복제할 필요가 없다.
이때 ReduceScatter는 두 단계를 한 번에 수행한다.
1. partial C들을 더한다.
2. 합산된 C를 shard로 나눠 각 rank에 둔다.
Reduce partial C and keep C sharded
rank 0
partial C0+partial C1+partial C2
after
C left|C middle|C right
언제 쓰는가
ReduceScatter는 partial result를 합쳐야 하지만, 합쳐진 결과를 계속 shard된 상태로 유지하고 싶을 때 쓴다.
partial full results on many ranks -> summed shard per rank
확인
- ReduceScatter는 reduce와 scatter 중 어떤 일을 함께 하는가?
- AllReduce와 비교했을 때 ReduceScatter가 full result 복제를 피하는 이유는 무엇인가?
- 다음 계산이
C의 shard만 필요로 한다면 왜 ReduceScatter가 자연스러운가?