Ring AllReduce

마지막 수정:

distributedcollectiveall-reducering

AllReduce는 의미만 보면 간단하다.

여러 rank의 값을 합산한다
그 합산 결과를 모든 rank가 갖는다

하지만 실제 구현에서 모든 rank가 한 rank로 한꺼번에 몰려가면 병목이 생길 수 있다. 그래서 많이 쓰는 구현 방식 중 하나가 Ring AllReduce다.

1. ReduceScatter

rank 0 A0+B0+C0
rank 1 A1+B1+C1
rank 2 A2+B2+C2

각 rank는 전체 합산 결과가 아니라, 자신이 맡은 reduced chunk 하나만 갖는다.

then

2. AllGather

rank 0 S0 S1 S2
rank 1 S0 S1 S2
rank 2 S0 S1 S2

reduced chunk들을 다시 돌려서 모든 rank가 full reduced tensor를 갖게 한다.

AllReduce result ReduceScatter + AllGather = 같은 합산 결과를 모두에게 복제
Ring AllReduce는 모든 rank가 한 곳으로 몰리는 구조가 아니라, 이웃과 chunk를 주고받으며 reduce와 gather를 단계적으로 수행한다.

두 단계로 보기

Ring AllReduce는 크게 두 단계로 볼 수 있다.

AllReduce = ReduceScatter + AllGather

첫 번째 단계는 ReduceScatter다.

1. tensor를 rank 수만큼 chunk로 나눈다.
2. 이웃 rank와 chunk를 주고받는다.
3. 받은 chunk를 자기 chunk와 더한다.
4. 반복 후 각 rank는 fully reduced chunk 하나를 갖는다.

이 시점에는 모든 rank가 full result를 갖지 않는다.

rank 0: reduced chunk 0
rank 1: reduced chunk 1
rank 2: reduced chunk 2

두 번째 단계는 AllGather다.

1. 각 rank가 가진 reduced chunk를 이웃에게 보낸다.
2. 받은 chunk를 계속 전달한다.
3. 모든 rank가 모든 reduced chunk를 모으면 full AllReduce 결과가 된다.

비용 직관

전체 tensor 길이를 K, rank 수를 N이라고 하자. Ring AllReduce에서 각 rank는 한 번에 전체 K를 보내는 것이 아니라 K / N 크기의 chunk를 여러 번 보낸다.

ReduceScatter 단계에서 대략:

(N - 1)번 전송 x K/N

AllGather 단계에서도 대략:

(N - 1)번 전송 x K/N

그래서 rank 하나가 주고받는 총량은:

2 x (N - 1) x K/N

N이 커지면 대략 2K에 가까워진다. 중요한 직관은 이것이다.

AllReduce는 하나의 마법 같은 연산이 아니라, reduced shard를 만든 뒤 그 shard들을 다시 모두에게 퍼뜨리는 흐름으로 구현될 수 있다.

확인

  • Ring AllReduce의 첫 단계는 왜 ReduceScatter인가?
  • ReduceScatter가 끝난 직후 각 rank는 full result를 갖는가?
  • AllGather 단계는 무엇을 모든 rank에 복제하는가?