Mixed Precision Training

마지막 수정:

precisiontrainingmixed-precisionfp16bf16

Mixed precision training은 이름 그대로 여러 precision을 섞어 쓰는 학습 방식이다.

핵심은 이것이다.

빠르게 계산해도 되는 곳은 낮은 precision을 쓰고, 수치 안정성이 중요한 곳은 높은 precision을 유지한다.

Forward / GEMM BF16 / FP16 / FP8

큰 행렬곱은 낮은 precision으로 빠르게 계산한다.

Master weights FP32 or BF16

업데이트 누적은 더 안정적인 형식으로 유지한다.

Grad accumulation often FP32

작은 gradient가 사라지지 않게 누적 정밀도를 높인다.

Optimizer states higher precision

Adam의 momentum/variance는 수치 안정성이 중요하다.

Mixed precision은 모든 tensor를 낮은 precision으로 바꾸는 것이 아니라, 빠른 곳과 안정적이어야 하는 곳을 나눠 쓰는 전략이다.

그냥 전부 FP16으로 바꾸면 안 되나?

대개 잘 안 된다. 이유는 작은 값이 사라지기 쉽기 때문이다.

small gradient
small weight update
large sum / average accumulation

FP16은 range가 좁아서 아주 작은 값이 0으로 underflow될 수 있고, 아주 큰 값은 overflow될 수 있다.

BF16은 range가 넓어서 FP16보다 안정적인 경우가 많지만, mantissa가 적어서 precision은 더 거칠다.

보통 무엇을 섞나

일반적인 그림은 이렇다.

GEMM / activation:
  FP16 or BF16

master weights:
  FP32 또는 더 높은 precision

gradient accumulation:
  FP32가 자주 사용됨

optimizer states:
  FP32가 자주 사용됨

이 구조가 training-memory-overview와 연결된다. parameters, gradients, optimizer states가 각각 어떤 dtype으로 저장되는지에 따라 메모리 계산이 달라진다.

확인

  • mixed precision은 모든 tensor를 낮은 precision으로 바꾸는 것인가?
  • 낮은 precision 계산을 쓰면서도 master weights를 더 높은 precision으로 둘 수 있는 이유는 무엇인가?
  • optimizer states가 높은 precision을 요구하는 이유는 무엇인가?