import torch.nn.functional as F
from torch import nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm


def activation_quant(x, n_bits=8):
    """Per-token absmax quantization of activations to n_bits (symmetric)."""
    q_min = -2 ** (n_bits - 1)
    q_max = 2 ** (n_bits - 1) - 1
    # Scale each token so that its largest absolute value maps to q_max.
    scale = q_max / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    x_quant = (x * scale).round().clamp_(q_min, q_max) / scale
    return x_quant


def weight_quant(w):
    """Ternary (1.58-bit) quantization of weights to {-1, 0, +1} via absmean scaling."""
    scale = 1 / w.abs().mean().clamp_(min=1e-5)
    w_quant = (w * scale).round().clamp_(-1, 1) / scale
    return w_quant


class BitLinear(nn.Linear):
    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super(BitLinear, self).__init__(*kargs, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        # Apply RMS normalization to the input before quantizing the activations.
        rms_norm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = rms_norm(x)
        # A trick for implementing the Straight-Through Estimator (STE) using detach():
        # the forward pass sees the quantized values, while the backward pass
        # propagates gradients through the full-precision path.
        x_quant = x_norm + (activation_quant(x_norm, self.input_bits) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
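
# --- Usage sketch (illustrative, not from the original source) ---
# A minimal check, assuming hypothetical dimensions (in_features=16, out_features=32)
# and a random input batch, to show that BitLinear behaves as a drop-in replacement
# for nn.Linear and that gradients reach the full-precision weights via the STE.
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    layer = BitLinear(16, 32, bias=False)
    x = torch.randn(4, 16)
    y = layer(x)
    print(y.shape)  # torch.Size([4, 32])
    y.sum().backward()
    print(layer.weight.grad is not None)  # True: gradients flow through the detach() trick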