File size: 1,588 Bytes
c2fa56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from transformers import LlamaForCausalLM
import torch
from torch import nn

class ScaledLinear(nn.Linear):
    """``nn.Linear`` whose output is multiplied by a learnable per-feature scale.

    ``forward`` returns ``linear(x) * output_scales``. ``output_scales`` is a
    parameter of shape ``(1, out_features)``, initialised to ones, which
    broadcasts over the last dimension of the linear output.

    Bias is not supported yet. With a bias, the scale would also multiply the
    bias term, so a checkpoint's bias would need to be divided by the scales
    when loading (see the commented-out sketch at the end of this class).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: must be falsy. The default mirrors ``nn.Linear``, so callers
            must pass ``bias=False`` explicitly.
        device: optional device for the weight and ``output_scales``.
        dtype: optional dtype for the weight and ``output_scales``.

    Raises:
        NotImplementedError: if ``bias`` is truthy.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        # Validate before allocating anything. Use an explicit raise instead
        # of ``assert`` so the check survives ``python -O``. Otherwise a bias
        # would be silently accepted and scaled along with the output.
        if bias:
            raise NotImplementedError("bias not supported yet")  # need to divide bias by scales.
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # Allocate with the same device/dtype arguments as the weight, so the
        # multiply in forward() does not mix devices or dtypes.
        self.output_scales = nn.Parameter(
            torch.ones((1, out_features), device=device, dtype=dtype)
        )

    def forward(self, x):
        """Return ``linear(x)`` scaled by ``output_scales``, broadcast over the last dim."""
        return super().forward(x) * self.output_scales

    # Works for CPU but not CUDA.
    # Starting point if you need to add support for bias.
    # NOTE(review): the .t() below suggests checkpoints store the scales as
    # (out_features, 1) — confirm against the checkpoint format before enabling.
    # def _load_from_state_dict(self, *args, **kwargs):
    #     # Seems like transformers doesn't call load_state_dict.
    #     # args[0] - state_dict
    #     # args[1] - prefix
    #     args[0][f"{args[1]}output_scales"] = args[0][f"{args[1]}output_scales"].t()
    #     super()._load_from_state_dict(*args, **kwargs)
    #     if self.bias is not None:
    #         self.bias.data = self.bias.data / self.output_scales

class LLamaNuGPTQForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` with its projection layers swapped for ``ScaledLinear``.

    After the parent constructor has built the model, every ``nn.Linear``
    child (at any depth) named ``gate_proj``, ``up_proj``, ``down_proj``,
    ``q_proj``, ``k_proj``, ``v_proj`` or ``o_proj`` is replaced by a newly
    constructed ``ScaledLinear`` with the same ``in_features``,
    ``out_features`` and bias flag. All other modules are left in place.

    The replacements are freshly initialised rather than copied from the
    originals. Presumably their weights are filled in later when a checkpoint
    is loaded — confirm against the loading path. ``ScaledLinear`` currently
    rejects biased layers, so a matched layer with a bias fails construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        target_names = (
            "gate_proj", "up_proj", "down_proj",
            "q_proj", "k_proj", "v_proj", "o_proj",
        )

        def swap_projections(parent):
            # Depth-first, pre-order walk. A matched layer is replaced in place
            # and not descended into. Every other child is visited recursively.
            for child_name, child in parent.named_children():
                is_target = isinstance(child, nn.Linear) and child_name in target_names
                if not is_target:
                    swap_projections(child)
                    continue
                replacement = ScaledLinear(
                    child.in_features,
                    child.out_features,
                    child.bias is not None,
                )
                setattr(parent, child_name, replacement)

        swap_projections(self)