0xrushi committed on
Commit
bc32204
·
1 Parent(s): bcc2f8f

first commit

Browse files
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+ import transformers
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch.cuda.amp import custom_fwd, custom_bwd
9
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
10
+
11
+
12
class FrozenBNBLinear(nn.Module):
    """Linear layer backed by a frozen, block-wise quantized weight.

    The quantized ``weight``, its ``absmax`` scales and the quantization
    ``code`` are registered as buffers with gradients disabled.  An optional
    trainable ``adapter`` module may be attached after construction; when one
    is set, its output is added to the output of the frozen layer.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized weight together with its quantization state.

        Args:
            weight: quantized weight of shape ``(out_features, in_features)``.
            absmax: scaling values forwarded to ``dequantize_blockwise``.
            code: quantization map forwarded to ``dequantize_blockwise``.
            bias: optional bias; must be an ``nn.Parameter`` or ``None``.
        """
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map and, if set, add the adapter output."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Compare against None explicitly: the truthiness of a module depends on
        # whether it defines __len__ (nn.Sequential does), which is not what
        # "an adapter is attached" means.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen quantized layer from an existing ``nn.Linear``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
36
+
37
+
38
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes a weight and applies ``F.linear``.

    Gradients are produced for the input and (when present) the bias only;
    the quantized weight, ``absmax`` and ``code`` receive ``None``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        # Keep the quantized tensors (not the dequantized weight) for backward.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        return F.linear(input, dense_weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The quantized weight, absmax and code (inputs 1..3) must not need gradients.
        assert not any(ctx.needs_input_grad[1:4])
        _input, weights_quantized, absmax, code = ctx.saved_tensors
        # The weight is dequantized again here rather than stored from forward.
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # grad_output: [*batch, out_features]
        grad_input = torch.matmul(grad_output, dense_weight)
        if ctx._has_bias:
            # Collapse every leading dimension and sum over it.
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        else:
            grad_bias = None
        return grad_input, None, None, None, grad_bias
58
+
59
+
60
class FrozenBNBEmbedding(nn.Module):
    """Embedding layer backed by a frozen, block-wise quantized weight.

    The quantized ``weight``, its ``absmax`` scales and the quantization
    ``code`` are registered as buffers with gradients disabled.  An optional
    trainable ``adapter`` module may be attached after construction; when one
    is set, its output is added to the looked-up embeddings.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized weight together with its quantization state.

        Args:
            weight: quantized weight of shape ``(num_embeddings, embedding_dim)``.
            absmax: scaling values forwarded to ``dequantize_blockwise``.
            code: quantization map forwarded to ``dequantize_blockwise``.
        """
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up embeddings in the dequantized table; extra kwargs go to ``F.embedding``."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None explicitly: the truthiness of a module depends on
        # whether it defines __len__ (nn.Sequential does), which is not what
        # "an adapter is attached" means.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen quantized layer from an existing ``nn.Embedding``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
85
+
86
+
87
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize ``matrix`` with ``quantize_blockwise`` one chunk at a time.

    The flattened matrix is processed in slices of ``chunk_size`` elements and
    the per-chunk results are concatenated, so only one chunk is handed to
    ``quantize_blockwise`` at any moment.

    Args:
        matrix: tensor to quantize; it may be non-contiguous.
        chunk_size: number of elements per chunk; must be a multiple of 4096
            (presumably the block size used by ``quantize_blockwise`` --
            confirm against the bitsandbytes version in use).

    Returns:
        ``(matrix_i8, (absmax, code))`` where ``matrix_i8`` has the shape of
        ``matrix``, ``absmax`` is the concatenation of the per-chunk absmax
        tensors and ``code`` is the code returned for the last chunk.
    """
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    # reshape() returns a view for contiguous tensors (same as the former
    # view(-1)) but also accepts non-contiguous input instead of raising.
    flat_tensor = matrix.reshape(-1)
    for start in range(0, matrix.numel(), chunk_size):
        input_chunk = flat_tensor[start:start + chunk_size].clone()
        # The code returned for one chunk is fed back in for the next chunk.
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
102
+
103
+
104
def convert_to_int8(model):
    """Swap every ``nn.Linear`` / ``nn.Embedding`` child for its frozen 8-bit counterpart.

    The replacements are created with zero-filled ``uint8`` weights, zero-filled
    ``absmax`` (one entry per 4096 weight elements, rounded up) and a zero-filled
    256-entry ``code``; a linear layer keeps its original bias.
    NOTE(review): the zeros look like placeholders to be overwritten by a
    loaded state dict -- confirm against the caller.
    """
    # Snapshot the module list first: children are replaced while walking it.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(parent, child_name, replacement)
130
+
131
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are converted to 8-bit."""

    def __init__(self, config):
        super().__init__(config)

        # Convert the two sub-modules separately, attention first.
        for submodule in (self.attn, self.mlp):
            convert_to_int8(submodule)
137
+
138
+
139
class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J base model with its linear and embedding layers converted to 8-bit."""

    def __init__(self, config):
        # Build the regular model first, then swap its layers in place.
        super().__init__(config)
        convert_to_int8(self)
143
+
144
+
145
class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal LM with its linear and embedding layers converted to 8-bit."""

    def __init__(self, config):
        # Build the regular model first, then swap its layers in place.
        super().__init__(config)
        convert_to_int8(self)
149
+
150
class T5ForConditionalGeneration(transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
    """T5 seq2seq model with its linear and embedding layers converted to 8-bit."""

    def __init__(self, config):
        # Build the regular model first, then swap its layers in place.
        super().__init__(config)
        convert_to_int8(self)
154
+
155
# Monkey-patch transformers so that models it builds internally pick up the
# 8-bit variants defined above.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock
transformers.models.t5.modeling_t5.T5ForConditionalGeneration = T5ForConditionalGeneration

config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# The tokenizer has no dedicated padding token, so reuse end-of-sequence.
# `pad_token` must be a token *string*; the integer id goes on the config.
config.pad_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Built from the config only; the alternative that was left disabled here is
# GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)
gpt = GPTJForCausalLM(config)
165
+
166
def add_adapters(model, adapter_dim=4, p=0.1):
    """Attach low-rank adapters to the frozen 8-bit layers of ``model``.

    A ``FrozenBNBLinear`` gets an adapter only when its qualified name contains
    "attn", "mlp" or "head"; every ``FrozenBNBEmbedding`` gets one.  Each adapter
    is a down-projection to ``adapter_dim``, dropout with probability ``p`` and a
    bias-free up-projection whose weight is zero-initialised.
    """
    assert adapter_dim > 0

    for name, module in model.named_modules():
        if isinstance(module, FrozenBNBLinear):
            if not any(tag in name for tag in ("attn", "mlp", "head")):
                print("Not adding adapter to", name)
                continue
            print("Adding adapter to", name)
            down_proj = nn.Linear(module.in_features, adapter_dim, bias=False)
            up_proj = nn.Linear(adapter_dim, module.out_features, bias=False)
            module.adapter = nn.Sequential(down_proj, nn.Dropout(p=p), up_proj)
            print("Initializing", name)
            # Zero up-projection: the adapter's output starts out as all zeros.
            nn.init.zeros_(up_proj.weight)
        elif isinstance(module, FrozenBNBEmbedding):
            print("Adding adapter to", name)
            low_rank_table = nn.Embedding(module.num_embeddings, adapter_dim)
            up_proj = nn.Linear(adapter_dim, module.embedding_dim, bias=False)
            module.adapter = nn.Sequential(low_rank_table, nn.Dropout(p=p), up_proj)
            print("Initializing", name)
            # Zero up-projection: the adapter's output starts out as all zeros.
            nn.init.zeros_(up_proj.weight)
192
+
193
add_adapters(gpt)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt.to(device)
# Deserialize onto the CPU regardless of where the checkpoint was saved;
# load_state_dict() then copies each tensor into the model's parameters and
# buffers on `device`.  This replaces two near-identical branches, one of
# which called torch.load() without a map_location.
state_dict = torch.load('rewrite_and_paraphrase_pretrained_gptj8bit.pt', map_location=torch.device('cpu'))
gpt.load_state_dict(state_dict)
gpt.eval()
201
+
202
def inference(text):
    """Generate a sampled continuation of ``text`` and return it decoded.

    The prompt is tokenized (truncated to 128 tokens), moved to the device the
    model lives on, and passed to ``gpt.generate`` with top-k / top-p sampling
    and ``max_length=512``.  The first generated sequence is decoded with the
    tokenizer and returned as a string.
    """
    with torch.no_grad():
        prompt = tokenizer(text, truncation=True, padding=True, max_length=128, return_tensors='pt')
        # The tokenizer returns CPU tensors; they must be on the same device as
        # the model, otherwise generate() fails when the model is on CUDA.
        prompt = {key: value.to(device) for key, value in prompt.items()}
        out = gpt.generate(
            **prompt,
            max_length=512,
            top_k=50,
            top_p=0.9,
            temperature=1.0,
            do_sample=True,
            repetition_penalty=1.2,
            num_beams=1,
        )
        return tokenizer.decode(out[0])
208
+
209
+
210
# Expose `inference` through a text-in / text-out Gradio interface and start it.
iface = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",
)
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ transformers==4.14.1
4
+ bitsandbytes-cuda111==0.26.0
5
+ datasets==1.16.1
rewrite_and_paraphrase_pretrained_gptj8bit.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03b1fa07169c705ef7e9f68eccf4e3f7050a8bd12316d754fd5e25f0fb351f9b
3
+ size 6231829230