from transformers import LlamaForCausalLM
import torch
from torch import nn


class ScaledLinear(nn.Linear):
    """nn.Linear whose output is multiplied by a learned per-output-channel scale."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.output_scales = nn.Parameter(torch.ones((1, out_features)))
        assert bias == False, "bias not supported yet"  # need to divide bias by scales.

    def forward(self, x):
        return super().forward(x) * self.output_scales

    # Works for CPU but not CUDA.
    # Starting point if you need to add support for bias.
    # def _load_from_state_dict(self, *args, **kwargs):
    #     # Seems like transformers doesn't call load_state_dict.
    #     # args[0] - state_dict
    #     # args[1] - prefix
    #     args[0][f"{args[1]}output_scales"] = args[0][f"{args[1]}output_scales"].t()
    #     super()._load_from_state_dict(*args, **kwargs)
    #     if self.bias is not None:
    #         self.bias.data = self.bias.data / self.output_scales


class LLamaNuGPTQForCausalLM(LlamaForCausalLM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def replace_linear_modules(module):
            # Recursively swap the attention and MLP projections for ScaledLinear.
            for name, mod in module.named_children():
                if isinstance(mod, nn.Linear) and name in [
                    "gate_proj", "up_proj", "down_proj",
                    "q_proj", "k_proj", "v_proj", "o_proj",
                ]:
                    setattr(module, name, ScaledLinear(mod.in_features, mod.out_features, mod.bias is not None))
                else:
                    replace_linear_modules(mod)

        replace_linear_modules(self)
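

# --- Usage sketch (added illustration, not part of the original module) ---
# A minimal example of what ScaledLinear computes: a plain nn.Linear forward
# followed by an element-wise multiply with the learned per-output-channel
# scale. The checkpoint path further down is hypothetical.
if __name__ == "__main__":
    layer = ScaledLinear(16, 32, bias=False)
    x = torch.randn(4, 16)
    y = layer(x)  # shape (4, 32); equals the plain linear output * layer.output_scales
    print(y.shape)

    # Loading would presumably go through the usual transformers entry point;
    # the checkpoint would need an "output_scales" tensor for every replaced
    # projection (hypothetical path, shown for illustration only):
    # model = LLamaNuGPTQForCausalLM.from_pretrained("path/to/scaled-llama-checkpoint")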