Expert Parallel MoE Dispatch
마지막 수정:
Expert Parallelism은 Picotron에서 거의 비어 있는 축이다. 그래서 이 카드는 Picotron을 억지로 읽지 않고, Picotron의 좌표계에서 EP axis가 왜 필요한지를 Nanotron과 Megatron으로 채운다.
Tokens
token 1token 2token 3Router
top-2 선택
Experts
MoE layer의 forward는 dense MLP와 다르다.
dense MLP:
tokens -> same MLP -> outputs
MoE:
tokens -> router -> selected experts -> combine outputs
EP는 expert parameter를 rank들에 나눠 둔다.
expert_parallel_size = 4
num_experts = 8
rank 0: expert 0, 1
rank 1: expert 2, 3
rank 2: expert 4, 5
rank 3: expert 6, 7
이제 token이 expert 6으로 routed되면, 그 token activation은 expert 6을 가진 rank 3으로 이동해야 한다.
EP의 본질은 token dispatch다
EP는 parameter sharding만으로 끝나지 않는다. 매 step마다 router 결과가 바뀌므로 token 이동도 동적으로 바뀐다.
1. router가 token별 top-k expert를 고른다
2. token을 expert owner rank별 bucket으로 묶는다
3. all-to-all로 token activation을 expert rank에 보낸다
4. 각 rank가 local experts를 실행한다
5. expert output을 원래 token 위치로 되돌린다
이 흐름을 all-to-all 관점으로 보면 더 명확하다.
Before: row-sharded A[6 x 6]
AllToAll
각 rank의 [2 x 6] row block을 세 개의 [2 x 2] block으로 자르고, column 목적지별로 교환한다.
After: column-sharded A[6 x 6]
그림은 일반적인 layout 변환을 보여주지만, MoE에서는 행/열 block 대신 “token bucket”이 이동한다고 보면 된다.
before:
rank마다 자기 batch token을 가짐
after dispatch all-to-all:
rank마다 자기 local expert가 처리해야 할 token을 가짐
after combine all-to-all:
rank마다 원래 batch token output을 다시 가짐
Nanotron은 작은 MoE forward를 보여준다
Nanotron의 Qwen2MoELayer는 읽기 좋은 축약판이다.
routing_weights, routing_indices = self.router(hidden_states)
dispatched_inputs, inverse_permute_mapping, num_tokens_per_expert = self._dispatch_tokens(
hidden_states, routing_indices
)
expert_outputs = self.experts(dispatched_inputs, num_tokens_per_expert)
output = self._combine_expert_outputs(
expert_outputs["hidden_states"], inverse_permute_mapping, routing_weights
)
Router는 fp32 weight로 expert logits를 만들고 top-k를 고른다.
logits = F.linear(x.to(torch.float32), self.weight, bias=None)
routing_weights = F.softmax(logits, dim=-1, dtype=torch.float32)
routing_weights, routing_indices = torch.topk(routing_weights, k=top_k, dim=-1)
그 다음 token을 expert 순서로 permute한다.
num_tokens_per_expert = torch.bincount(routing_indices.flatten(), minlength=num_local_experts)
dispatched_inputs, inverse_permute_mapping = ops.permute(hidden_states, routing_indices)
Nanotron 코드의 한계도 중요하다. 주석에 “full implementation would handle communication between devices”라고 되어 있다. 즉 이 코드는 MoE local permutation과 grouped GEMM의 구조는 보여주지만, production EP all-to-all까지 완성된 교육용 구현은 아니다.
Megatron은 dispatcher를 별도 subsystem으로 둔다
Megatron의 MoELayer는 router, experts, dispatcher를 분리한다.
self.router = TopKRouter(...)
self.experts = ...
self.token_dispatcher = MoEAllGatherTokenDispatcher(...)
# or MoEAlltoAllTokenDispatcher(...)
# or MoEFlexTokenDispatcher(...)
forward 흐름도 dispatcher 중심이다.
router
-> dispatch_preprocess
-> token_dispatch
-> dispatch_postprocess
-> experts
-> combine_preprocess
-> token_combine
-> combine_postprocess
Megatron이 이렇게 나누는 이유는 MoE dispatch가 단순한 all_to_all 하나가 아니기 때문이다.
router load balancing
token dropping / capacity
permute fusion
all-gather vs all-to-all vs flex dispatcher
DeepEP / HybridEP fused dispatch
grouped GEMM
shared experts
CUDA graph capture constraints
작은 교육용 구현에서는 permute -> all-to-all -> local expert -> all-to-all -> unpermute만 보이면 충분하다. Megatron은 이 경로의 병목을 줄이기 위해 dispatcher를 고도화한다.
실습
moe_dispatch_sim.py는 실제 expert MLP를 실행하지 않는다. router 결과가 어떻게 expert owner rank별 send bucket으로 바뀌는지 출력한다.
python3 labs/large-scale-training-parallelism/moe_dispatch_sim.py --ep 4 --experts 8 --tokens 12 --top-k 2
출력에서 봐야 할 것은 세 가지다.
token -> experts -> owner ranks
send buckets for all-to-all
tokens_per_expert
tokens_per_expert가 균등하지 않으면 특정 expert rank가 병목이 된다. 그래서 production MoE에는 aux loss, capacity factor, token dropping, load balancing router가 붙는다.
읽는 순서
EP/MoE는 다음 순서로 읽는 것이 좋다.
- 기존
mixture-of-experts카드로 router와 top-k expert 선택을 이해한다. - 이 카드에서 expert owner rank와 token dispatch 문제를 잡는다.
- Nanotron
Router와Qwen2MoELayer._core_forward를 읽어 local permutation 흐름을 본다. - Megatron
MoELayer에서 dispatcher가 별도 객체로 분리되는 이유를 확인한다. - Megatron
token_dispatcher.py에서 all-gather, all-to-all, flex dispatcher의 차이를 본다. - 마지막으로
fused_a2a.py, DeepEP/HybridEP, paged stash 같은 production 최적화는 병목을 이해한 뒤 읽는다.
EP를 이해했다는 신호는 “expert를 여러 GPU에 둔다”가 아니다. router가 만든 token-to-expert mapping을 어떻게 communication plan으로 바꾸고, 다시 원래 token layout으로 복구하는지를 설명할 수 있어야 한다.