import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

class ReplacedLinearLayer(nn.Module):
    """Drop-in replacement for nn.Linear / transformers Conv1D that stores
    int8 weights with one float32 absmax scale per output channel and
    dequantizes on the fly in forward()."""

    def __init__(self, input_dim, output_dim, if_conv=True):
        super().__init__()

        # int8 weight matrix in (out, in) layout plus per-output-channel scales.
        # The scales must be float32 (they are fractional), and the buffer must
        # be named `scales` to match forward() and do_quantization().
        self.register_buffer('weights', torch.zeros(output_dim, input_dim, dtype=torch.int8))
        self.register_buffer('scales', torch.zeros(output_dim, dtype=torch.float32))

        self.bias = None
        self.if_conv = if_conv

    def forward(self, x):
        # Dequantize on the fly: cast the int8 weights to the activation
        # dtype, run the matmul, then rescale each output channel.
        w = self.weights.to(x.dtype)
        out = F.linear(x, w) * self.scales.to(x.dtype)
        if self.bias is not None:
            out = out + self.bias
        return out
    
    def do_quantization(self, W):
        # transformers Conv1D keeps weights as (in_features, out_features),
        # so transpose to the (out, in) layout that F.linear expects.
        if self.if_conv:
            W32 = W.detach().clone().T
        else:
            W32 = W.detach().clone()

        # Symmetric per-row absmax quantization: map each output channel's
        # largest magnitude onto the int8 range [-127, 127].
        scales = (W32.abs().max(dim=-1)[0] / 127).to(torch.float32)
        scales = scales.clamp(min=torch.finfo(torch.float32).tiny)  # avoid div-by-zero on all-zero rows
        self.scales = scales
        self.weights = torch.round(W32 / scales[:, None]).to(torch.int8)

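
# Illustration (not part of the original module): a minimal round-trip check
# for the absmax scheme above. With per-row scales, W is reconstructed as
# weights * scales[:, None], and the elementwise error is bounded by half a
# quantization step (scales.max() / 2). The helper name `_demo_roundtrip` is
# ours, purely for demonstration; it is never called on import.
def _demo_roundtrip():
    torch.manual_seed(0)
    layer = ReplacedLinearLayer(64, 32, if_conv=False)
    W = torch.randn(32, 64)  # (out, in), as nn.Linear stores it
    layer.do_quantization(W)
    W_hat = layer.weights.float() * layer.scales[:, None]
    max_err = (W - W_hat).abs().max().item()
    print(f"max |W - W_hat| = {max_err:.5f} <= {layer.scales.max().item() / 2:.5f}")
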

def perform_quantization(module, regex='.*'):
    """Replace every nn.Linear / transformers Conv1D whose qualified name
    matches `regex` with an int8 ReplacedLinearLayer."""
    pattern = re.compile(regex)
    for name, node in module.named_modules():
        for name2, child in node.named_children():
            is_conv = isinstance(child, transformers.pytorch_utils.Conv1D)
            if (isinstance(child, nn.Linear) or is_conv) and pattern.match(f'{name}.{name2}'):
                fp32_weight, fp32_bias = child.weight, child.bias

                # nn.Linear weights are (out, in); Conv1D weights are (in, out).
                if is_conv:
                    input_dim, output_dim = fp32_weight.shape
                else:
                    output_dim, input_dim = fp32_weight.shape

                quant_module = ReplacedLinearLayer(input_dim, output_dim, if_conv=is_conv)
                quant_module.do_quantization(fp32_weight)
                if fp32_bias is not None:
                    quant_module.bias = fp32_bias
                setattr(node, name2, quant_module)
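

# Illustration (not part of the original module): end-to-end usage on a toy
# model containing both an nn.Linear and a transformers Conv1D, checking that
# the int8 outputs stay close to the fp32 reference. The class name `Toy` is
# ours, purely for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 16)
            # Conv1D(nf, nx) maps nx inputs to nf outputs (GPT-2 style)
            self.proj = transformers.pytorch_utils.Conv1D(8, 16)

        def forward(self, x):
            return self.proj(torch.relu(self.fc(x)))

    model = Toy()
    x = torch.randn(4, 16)
    with torch.no_grad():
        ref = model(x)
        perform_quantization(model)  # replaces both fc and proj in place
        out = model(x)

    print("children:", [type(m).__name__ for m in model.children()])
    print("max |fp32 - int8| output gap:", (ref - out).abs().max().item())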