AllToAll

distributedcollectiveall-to-all

AllToAll은 각 rank가 가진 tensor를 목적지별로 다시 잘라서, 모든 rank가 서로 필요한 조각을 교환하는 collective operation이다.

지금까지 본 gather, reduce 계열과 달리 AllToAll은 값을 더하지 않는다. full tensor를 모든 rank에 복제하지도 않는다. 핵심은 layout을 바꾸는 것이다.

Before: row-sharded A[6 x 6]

rank 0
A00A01A02
rows 0-1, cols 0-5
rank 1
A10A11A12
rows 2-3, cols 0-5
rank 2
A20A21A22
rows 4-5, cols 0-5
->

AllToAll

각 rank의 [2 x 6] row block을 세 개의 [2 x 2] block으로 자르고, column 목적지별로 교환한다.

After: column-sharded A[6 x 6]

rank 0
A00A10A20
rows 0-5, cols 0-1
rank 1
A01A11A21
rows 0-5, cols 2-3
rank 2
A02A12A22
rows 0-5, cols 4-5
AllToAll은 값을 더하지 않고 full tensor를 복제하지도 않는다. 각 rank가 가진 조각을 목적지별로 교환해서 sharding layout을 바꾼다.

하나의 행렬로 보기

전체 행렬이 있다고 하자.

A[6 x 6]

처음에는 row 방향으로 3개 rank에 나뉘어 있다.

Before: row-sharded

rank 0: rows 0-1, cols 0-5 -> [2 x 6]
rank 1: rows 2-3, cols 0-5 -> [2 x 6]
rank 2: rows 4-5, cols 0-5 -> [2 x 6]

그런데 다음 계산은 column 방향 shard를 원한다고 하자.

After: column-sharded

rank 0: rows 0-5, cols 0-1 -> [6 x 2]
rank 1: rows 0-5, cols 2-3 -> [6 x 2]
rank 2: rows 0-5, cols 4-5 -> [6 x 2]

이 변환을 하려면 각 rank가 자기 [2 x 6] block을 column 기준으로 세 개의 [2 x 2] block으로 자른다.

rank 0:
  rows 0-1, cols 0-1 -> rank 0
  rows 0-1, cols 2-3 -> rank 1
  rows 0-1, cols 4-5 -> rank 2

rank 1:
  rows 2-3, cols 0-1 -> rank 0
  rows 2-3, cols 2-3 -> rank 1
  rows 2-3, cols 4-5 -> rank 2

rank 2:
  rows 4-5, cols 0-1 -> rank 0
  rows 4-5, cols 2-3 -> rank 1
  rows 4-5, cols 4-5 -> rank 2

AllToAll은 이 조각들을 목적지 rank로 교환한다. 교환이 끝나면 각 rank는 자신이 맡은 column range에 해당하는 모든 row를 갖게 된다.

다른 collective와 비교

Gather:
  여러 shard를 한 rank에 모은다.

AllGather:
  여러 shard를 모두 모아서 모든 rank가 full tensor를 갖게 한다.

ReduceScatter:
  값을 합친 뒤, 합산 결과를 shard로 나눠 갖는다.

AllToAll:
  값을 합치지 않는다.
  full tensor를 복제하지 않는다.
  조각들을 목적지별로 교환해서 layout을 바꾼다.

언제 쓰는가

AllToAll은 현재 tensor의 sharding 방향과 다음 계산이 원하는 sharding 방향이 다를 때 유용하다.

row-sharded layout -> column-sharded layout

나중에 MoE를 볼 때도 같은 패턴이 나온다. token을 expert가 있는 rank로 보내고, 계산이 끝난 뒤 다시 원래 위치로 돌려보낼 때 AllToAll 계열의 통신이 쓰인다. 하지만 이 카드에서는 MoE보다 먼저, AllToAll을 layout exchange로 이해하는 것이 핵심이다.

확인

  • AllToAll은 값을 더하는가?
  • AllToAll 이후 모든 rank가 full tensor를 갖는가?
  • row-sharded A[6 x 6]을 column-sharded 형태로 바꿀 때 각 rank는 자기 block을 어떤 기준으로 잘라 보내는가?