malenia1 committed
Commit
ad5b231
1 Parent(s): e5f3910

Upload modeling.py with huggingface_hub

Files changed (1)
  1. modeling.py +107 -0
modeling.py ADDED
import os

# Make the CUDA toolchain visible to bitblas's JIT compiler and pin device
# numbering to the PCI bus order.
os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ["PATH"]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import bitblas
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PreTrainedModel, AutoModel, AutoConfig

class bitlinear(bitblas.Linear):
    """bitblas int2 linear layer with a fixed output scale alpha and bias b."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        A_dtype: str = "float16",
        W_dtype: str = "int2",
        accum_dtype: str = "float16",
        out_dtype: str = "float16",
        group_size: int = -1,
        with_scaling: bool = False,
        with_zeros: bool = False,
        zeros_mode: str = None,
        opt_M: list = [1, 16, 32, 64, 128, 256, 512],
        fast_decoding: bool = True,
        alpha: torch.Tensor = None,
        b: torch.Tensor = None,
    ):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            A_dtype=A_dtype,
            W_dtype=W_dtype,
            accum_dtype=accum_dtype,
            out_dtype=out_dtype,
            group_size=group_size,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            zeros_mode=zeros_mode,
            opt_M=opt_M,
            fast_decoding=fast_decoding,
        )
        # alpha rescales the int2 matmul output; b carries the original bias.
        if alpha is None:
            alpha = torch.tensor(1.0, dtype=torch.float16)
        self.alpha = nn.Parameter(alpha, requires_grad=False)
        self.b = nn.Parameter(b, requires_grad=False) if b is not None else None

    def forward(self, A: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        out = super().forward(A, out)  # fp16 GEMM against the int2 weights
        out *= self.alpha              # undo the quantization scale
        if self.b is not None:
            out += self.b.view(1, -1).expand_as(out)
        return out.to(torch.float32)

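# Illustrative sketch (not part of the original upload): exercising a single
# bitlinear layer on its own. Assumes a CUDA device; a freshly built layer
# holds uninitialized int2 weights, so outputs are meaningless until real
# quantized weights are loaded.
#
#   layer = bitlinear(in_features=768, out_features=768,
#                     alpha=torch.tensor(1.0, dtype=torch.float16),
#                     b=torch.zeros(768, dtype=torch.float16)).cuda()
#   x = torch.randn(4, 768, dtype=torch.float16, device="cuda")
#   y = layer(x)  # fp32 output of shape (4, 768)
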
class TernaryBertConfig(BertConfig):
    model_type = "ternarybert"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

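# Illustrative sketch (not part of the original upload): AutoConfig and
# AutoModel are imported above but never used; a plausible intent is
# custom-class registration once TernaryBert below is defined, e.g.:
#
#   AutoConfig.register("ternarybert", TernaryBertConfig)
#   AutoModel.register(TernaryBertConfig, TernaryBert)
#
# For Auto-class dispatch to work end to end, TernaryBert would also need
# config_class = TernaryBertConfig (currently commented out below).
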
class TernaryBert(PreTrainedModel):
    # config_class = TernaryBertConfig
    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Swap every nn.Linear in the BERT backbone for a bitblas int2 layer.
        self.replace_linear2bitblas(self.bert)

    def forward(self, **kwargs):
        return self.bert(**kwargs)

    def convert_to_bitlinear(self, layer):
        bitlayer = bitlinear(
            in_features=layer.in_features,
            out_features=layer.out_features,
            bias=False,
            A_dtype="float16",      # activation A dtype
            W_dtype="int2",         # weight W dtype
            accum_dtype="float16",  # accumulation dtype
            out_dtype="float16",    # output dtype
            # configs for weight-only quantization
            group_size=-1,          # grouped quantization disabled
            with_scaling=False,     # no per-group scaling factors
            with_zeros=False,       # no zero points
            zeros_mode=None,        # how zeros would be calculated
            # Target optimization vars for dynamic symbolic shapes; see
            # docs/PythonAPI.md. The default is [1, 16, 32, 64, 128, 256, 512].
            opt_M=[1, 16, 32, 64, 128, 256, 512],
            fast_decoding=True,
            alpha=torch.tensor(1.0, dtype=torch.float16),
            b=layer.bias.data.to(torch.float16) if layer.bias is not None else None,
        )
        return bitlayer

    def replace_linear2bitblas(self, model):
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                setattr(model, name, self.convert_to_bitlinear(module))
            elif len(list(module.children())) > 0:
                self.replace_linear2bitblas(module)
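
A minimal end-to-end usage sketch, not part of the upload: it assumes a CUDA device, a bert-base-uncased config and tokenizer, and that compatible quantized weights would be loaded separately (the checkpoint location is unknown). Note that the BERT backbone runs in fp32 while the bitblas kernels expect fp16 activations, so extra casts may be needed in practice.

import torch
from transformers import AutoTokenizer, BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")  # assumed base config
model = TernaryBert(config).cuda().eval()
# Real quantized weights would be loaded here, e.g. model.load_state_dict(...).

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("ternary BERT smoke test", return_tensors="pt").to("cuda")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
print(hidden.shape)  # (1, seq_len, 768) for a bert-base configuration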