import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer


class ReplacedLinearLayer(nn.Module):
    """Drop-in replacement for nn.Linear / Conv1D that stores int8 weights
    plus per-output-channel float32 scales (weight-only quantization)."""

    def __init__(self, input_dim, output_dim, if_conv=False):
        super().__init__()
        # int8 weight matrix in nn.Linear layout: [output_dim, input_dim]
        self.register_buffer('weights', torch.zeros([output_dim, input_dim], dtype=torch.int8))
        # per-output-channel scales must stay in float32, not int8
        self.register_buffer('scales', torch.zeros(output_dim, dtype=torch.float32))
        self.bias = None
        self.if_conv = if_conv  # transformers Conv1D stores its weight as [in, out]

    def forward(self, x):
        # Dequantize on the fly: cast the int8 weights to the activation dtype,
        # run the matmul, then rescale each output channel.
        w = self.weights.to(x.dtype)
        out = F.linear(x, w) * self.scales
        if self.bias is not None:
            out = out + self.bias
        return out

    def do_quantization(self, W):
        # Conv1D weights are [in, out]; transpose to the nn.Linear layout [out, in].
        if self.if_conv:
            W32 = W.clone().squeeze().T
        else:
            W32 = W.clone()
        # Symmetric per-row (per output channel) quantization to int8.
        scales = (torch.max(W32.abs(), dim=-1)[0] / 127).to(torch.float32)
        self.scales = scales
        self.weights = torch.round(W32 / scales[:, None]).to(torch.int8)


def perform_quantization(module, regex='.*'):
    """Replace every nn.Linear / Conv1D whose qualified name matches `regex`
    with a ReplacedLinearLayer holding its quantized weights."""
    pattern = re.compile(regex)
    for name, node in module.named_modules():
        for name2, child in node.named_children():
            is_linear = isinstance(child, nn.Linear)
            is_conv1d = isinstance(child, transformers.pytorch_utils.Conv1D)
            if (is_linear or is_conv1d) and pattern.match(f'{name}.{name2}'):
                fp32_weight, fp32_bias = child.weight, child.bias
                # nn.Linear weights are [out, in]; Conv1D weights are [in, out].
                if is_conv1d:
                    in_features, out_features = child.weight.shape
                else:
                    out_features, in_features = child.weight.shape
                quant_module = ReplacedLinearLayer(in_features, out_features, if_conv=is_conv1d)
                setattr(node, name2, quant_module)
                getattr(node, name2).do_quantization(fp32_weight)
                if fp32_bias is not None:
                    getattr(node, name2).bias = fp32_bias
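

# --- Usage sketch (an assumption, not part of the original module) ---
# A minimal example of applying perform_quantization to a small causal LM.
# "gpt2" is chosen here purely for illustration; its attention/MLP projections
# are transformers Conv1D layers, so this also exercises the if_conv path.
# The regex below is likewise an assumed choice: it quantizes the transformer
# blocks and leaves the LM head in full precision.
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Replace matching layers in place with int8 weight-only versions.
    perform_quantization(model, regex=r"transformer\.h\..*")

    inputs = tokenizer("Quantization keeps the model", return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))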