AllToAll
AllToAll은 각 rank가 가진 tensor를 목적지별로 다시 잘라서, 모든 rank가 서로 필요한 조각을 교환하는 collective operation이다.
지금까지 본 gather, reduce 계열과 달리 AllToAll은 값을 더하지 않는다. full tensor를 모든 rank에 복제하지도 않는다. 핵심은 layout을 바꾸는 것이다.
Before: row-sharded A[6 x 6]
rank 0
A00A01A02
rows 0-1, cols 0-5 rank 1
A10A11A12
rows 2-3, cols 0-5 rank 2
A20A21A22
rows 4-5, cols 0-5 ->
AllToAll
각 rank의 [2 x 6] row block을 세 개의 [2 x 2] block으로 자르고, column 목적지별로 교환한다.
After: column-sharded A[6 x 6]
rank 0
A00A10A20
rows 0-5, cols 0-1 rank 1
A01A11A21
rows 0-5, cols 2-3 rank 2
A02A12A22
rows 0-5, cols 4-5 하나의 행렬로 보기
전체 행렬이 있다고 하자.
A[6 x 6]
처음에는 row 방향으로 3개 rank에 나뉘어 있다.
Before: row-sharded
rank 0: rows 0-1, cols 0-5 -> [2 x 6]
rank 1: rows 2-3, cols 0-5 -> [2 x 6]
rank 2: rows 4-5, cols 0-5 -> [2 x 6]
그런데 다음 계산은 column 방향 shard를 원한다고 하자.
After: column-sharded
rank 0: rows 0-5, cols 0-1 -> [6 x 2]
rank 1: rows 0-5, cols 2-3 -> [6 x 2]
rank 2: rows 0-5, cols 4-5 -> [6 x 2]
이 변환을 하려면 각 rank가 자기 [2 x 6] block을 column 기준으로 세 개의 [2 x 2] block으로 자른다.
rank 0:
rows 0-1, cols 0-1 -> rank 0
rows 0-1, cols 2-3 -> rank 1
rows 0-1, cols 4-5 -> rank 2
rank 1:
rows 2-3, cols 0-1 -> rank 0
rows 2-3, cols 2-3 -> rank 1
rows 2-3, cols 4-5 -> rank 2
rank 2:
rows 4-5, cols 0-1 -> rank 0
rows 4-5, cols 2-3 -> rank 1
rows 4-5, cols 4-5 -> rank 2
AllToAll은 이 조각들을 목적지 rank로 교환한다. 교환이 끝나면 각 rank는 자신이 맡은 column range에 해당하는 모든 row를 갖게 된다.
다른 collective와 비교
Gather:
여러 shard를 한 rank에 모은다.
AllGather:
여러 shard를 모두 모아서 모든 rank가 full tensor를 갖게 한다.
ReduceScatter:
값을 합친 뒤, 합산 결과를 shard로 나눠 갖는다.
AllToAll:
값을 합치지 않는다.
full tensor를 복제하지 않는다.
조각들을 목적지별로 교환해서 layout을 바꾼다.
언제 쓰는가
AllToAll은 현재 tensor의 sharding 방향과 다음 계산이 원하는 sharding 방향이 다를 때 유용하다.
row-sharded layout -> column-sharded layout
나중에 MoE를 볼 때도 같은 패턴이 나온다. token을 expert가 있는 rank로 보내고, 계산이 끝난 뒤 다시 원래 위치로 돌려보낼 때 AllToAll 계열의 통신이 쓰인다. 하지만 이 카드에서는 MoE보다 먼저, AllToAll을 layout exchange로 이해하는 것이 핵심이다.
확인
- AllToAll은 값을 더하는가?
- AllToAll 이후 모든 rank가 full tensor를 갖는가?
- row-sharded
A[6 x 6]을 column-sharded 형태로 바꿀 때 각 rank는 자기 block을 어떤 기준으로 잘라 보내는가?