Rank와 Process Group

마지막 수정:

distributedcollectiverankprocess-group

collective operation은 “여러 GPU가 같이 부르는 함수”다. 그래서 먼저 참여자가 누구인지 정해야 한다.

process group = 통신에 참여하는 rank들의 묶음
rank          = 그 묶음 안에서 각 process/GPU가 가진 번호
world size    = process group 안 rank 개수

예를 들어 GPU 3개가 하나의 process group에 들어가면:

world size = 3

rank 0
rank 1
rank 2

Root, source, destination

일부 collective는 중심 rank가 있다.

broadcast: source rank가 값을 보낸다
scatter:   source rank가 조각을 나눠 보낸다
gather:    destination rank가 조각을 모은다
reduce:    destination rank가 합산 결과를 받는다

문서에서는 이 중심 rank를 보통 root라고 부른다. PyTorch API에서는 연산에 따라 src 또는 dst라는 이름을 쓴다.

dist.broadcast(tensor, src=0)
dist.reduce(tensor, dst=0)

반대로 all_* 연산은 특정 root 하나만 결과를 갖는 구조가 아니다.

all_gather:  모든 rank가 full tensor를 갖는다
all_reduce:  모든 rank가 reduced result를 갖는다

왜 모든 rank가 같이 호출해야 하나

collective는 혼자 부르는 함수가 아니다. 같은 process group 안 rank들이 같은 통신 약속에 참여해야 한다.

rank 0: broadcast 호출
rank 1: broadcast 호출
rank 2: broadcast 호출

만약 rank 0만 broadcast를 부르고 rank 1, 2가 다른 일을 하고 있으면 통신 상대가 맞지 않는다. 실제 코드에서는 멈추거나 timeout이 날 수 있다.

PyTorch에서의 기본 모양

실전 코드는 보통 process group을 초기화한 뒤 rank와 world size를 읽는다.

import torch.distributed as dist

dist.init_process_group(backend="nccl")

rank = dist.get_rank()
world_size = dist.get_world_size()

여기서 backend="nccl"은 GPU 간 통신에 자주 쓰는 backend다.

확인

  • rank는 GPU 메모리에 있는 tensor 조각인가, process group 안 번호인가?
  • broadcast(tensor, src=0)에서 src=0은 어떤 rank를 뜻하는가?
  • all_reduce에는 왜 dst가 필요하지 않은가?