Ternary and BitNet for CPU Inference

마지막 수정:

Draft
설명: 점검 전 위젯: 점검 전 링크: 점검 전
quantizationternarybitnetcpu-inferencelow-bitinference

Ternary는 weight를 세 값으로 제한하는 표현 방식이다.

weight in {-1, 0, +1}

BitNet은 이 ternary 제약을 전제로 설계하고 훈련한 model family에 가깝다. b1.58이라는 이름은 세 가지 상태를 표현하는 데 필요한 정보량이 log2(3) ~= 1.58 bit라는 점에서 온다.

ternary:
three-value weight representation

BitNet:
model trained to work under ternary-like constraints

CPU-only와 무슨 관련이 있나?

CPU-only inference의 문제는 GPU tensor core가 없다는 것이다. 큰 linear layer를 계속 실행해야 하는데, 범용 CPU에서는 matrix multiply throughput과 memory bandwidth가 모두 제한된다.

Ternary는 여기서 두 가지를 노린다.

1. weight memory traffic 감소
   FP16 16 bits -> ternary about 1.58 bits

2. weight multiplication 제거
   w * x -> add / subtract / skip

수학적으로는 여전히 dot product다. 하지만 kernel이 실제로 하는 일은 달라질 수 있다.

+1 x0 add
0 x1 skip
-1 x2 subtract
+1 x3 add
-1 x4 subtract
0 x5 skip
ordinary dot product sum(w_i * x_i)
ternary kernel view sum(x_pos) - sum(x_neg)
Ternary는 linear algebra를 없애지 않는다. Dot product는 남지만 weight multiplication이 packed code를 읽어 add, subtract, skip을 고르는 문제로 바뀐다. 실제 이득은 이 표현을 아는 CPU kernel이 있을 때 나온다.

Dot product는 남지만 cost model이 바뀐다

일반 linear layer는 이렇게 계산된다.

acc += w_i * x_i

Ternary weight에서는 다음처럼 바뀐다.

w_i = +1 -> acc += x_i
w_i = -1 -> acc -= x_i
w_i =  0 -> skip

즉 일반적인 multiply-accumulate가 packed weight code를 읽고, positive/negative/zero mask에 따라 activation을 더하거나 빼는 문제로 바뀐다.

이 말은 “계산이 공짜가 된다”는 뜻이 아니다.

still needed:
activation loads
packed weight decode
mask/select logic
accumulation

그래서 ternary가 실제로 빠르려면 전용 CPU kernel이 필요하다. Packed ternary weight를 cache에 잘 올리고, SIMD나 bit operation으로 여러 code를 한 번에 처리해야 한다.

PTQ로는 보통 망가진다

Ternary는 INT4보다 훨씬 더 거칠다.

INT4:
16 levels

ternary:
3 levels

기존 BF16/FP16 model의 continuous weight를 사후에 단순히 {-1, 0, +1}로 반올림하면 정보가 너무 많이 사라진다.

그래서 ternary는 일반적인 PTQ 압축 포맷이라기보다 model training constraint에 가깝다.

bad:
train dense model
-> post-hoc ternary rounding

better:
train or further-train while model sees ternary constraint

BitNet 같은 모델이 중요한 이유도 여기에 있다. Ternary constraint를 모델이 학습 과정에서 흡수했기 때문에 CPU-oriented low-bit execution이 의미를 갖는다.

확인

  • Ternary와 BitNet의 관계는 무엇인가?
  • Ternary dot product에서 사라지는 것과 여전히 남는 비용은 무엇인가?
  • 기존 모델을 단순 PTQ로 ternary화하면 위험한 이유는 무엇인가?