Ternary and BitNet for CPU Inference
마지막 수정:
Ternary는 weight를 세 값으로 제한하는 표현 방식이다.
weight in {-1, 0, +1}
BitNet은 이 ternary 제약을 전제로 설계하고 훈련한 model family에 가깝다. b1.58이라는 이름은 세 가지 상태를 표현하는 데 필요한 정보량이 log2(3) ~= 1.58 bit라는 점에서 온다.
ternary:
three-value weight representation
BitNet:
model trained to work under ternary-like constraints
CPU-only와 무슨 관련이 있나?
CPU-only inference의 문제는 GPU tensor core가 없다는 것이다. 큰 linear layer를 계속 실행해야 하는데, 범용 CPU에서는 matrix multiply throughput과 memory bandwidth가 모두 제한된다.
Ternary는 여기서 두 가지를 노린다.
1. weight memory traffic 감소
FP16 16 bits -> ternary about 1.58 bits
2. weight multiplication 제거
w * x -> add / subtract / skip
수학적으로는 여전히 dot product다. 하지만 kernel이 실제로 하는 일은 달라질 수 있다.
sum(w_i * x_i) sum(x_pos) - sum(x_neg) Dot product는 남지만 cost model이 바뀐다
일반 linear layer는 이렇게 계산된다.
acc += w_i * x_i
Ternary weight에서는 다음처럼 바뀐다.
w_i = +1 -> acc += x_i
w_i = -1 -> acc -= x_i
w_i = 0 -> skip
즉 일반적인 multiply-accumulate가 packed weight code를 읽고, positive/negative/zero mask에 따라 activation을 더하거나 빼는 문제로 바뀐다.
이 말은 “계산이 공짜가 된다”는 뜻이 아니다.
still needed:
activation loads
packed weight decode
mask/select logic
accumulation
그래서 ternary가 실제로 빠르려면 전용 CPU kernel이 필요하다. Packed ternary weight를 cache에 잘 올리고, SIMD나 bit operation으로 여러 code를 한 번에 처리해야 한다.
PTQ로는 보통 망가진다
Ternary는 INT4보다 훨씬 더 거칠다.
INT4:
16 levels
ternary:
3 levels
기존 BF16/FP16 model의 continuous weight를 사후에 단순히 {-1, 0, +1}로 반올림하면 정보가 너무 많이 사라진다.
그래서 ternary는 일반적인 PTQ 압축 포맷이라기보다 model training constraint에 가깝다.
bad:
train dense model
-> post-hoc ternary rounding
better:
train or further-train while model sees ternary constraint
BitNet 같은 모델이 중요한 이유도 여기에 있다. Ternary constraint를 모델이 학습 과정에서 흡수했기 때문에 CPU-oriented low-bit execution이 의미를 갖는다.
확인
- Ternary와 BitNet의 관계는 무엇인가?
- Ternary dot product에서 사라지는 것과 여전히 남는 비용은 무엇인가?
- 기존 모델을 단순 PTQ로 ternary화하면 위험한 이유는 무엇인가?