PyTorch DDP Training Loop

마지막 수정:

pytorchdistributedddptraining

DDP는 같은 model replica를 여러 GPU에 두고, 각 process가 서로 다른 mini-batch shard를 처리하게 한다.

rank 0: model replica + data shard 0
rank 1: model replica + data shard 1
rank 2: model replica + data shard 2
rank 3: model replica + data shard 3

PyTorch에서는 model을 GPU로 옮긴 뒤 DistributedDataParallel로 감싼다.

model = Transformer(config).to(local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank],
)

데이터도 rank별로 나눠야 한다.

sampler = torch.utils.data.DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

DDP는 backward 중에 gradient bucket이 준비되면 all-reduce를 걸어준다. 그래서 단일 GPU loop와 거의 같은 코드처럼 보이지만, 실제로는 backward 뒤에서 통신이 진행된다.

gradient accumulation을 할 때는 중간 step에서 통신을 끄는 것이 중요하다.

with model.no_sync():
    loss.backward()

확인

  • DDP에서 model replica는 rank마다 같은가, 다른가?
  • DistributedSampler가 없으면 어떤 문제가 생길 수 있는가?
  • gradient accumulation 중간 step에서 no_sync()가 필요한 이유는 무엇인가?