mrm8488 committed
Commit b713267
1 Parent(s): f276499

Update README.md

Files changed (1):
  README.md +174 -0
README.md CHANGED
@@ -51,3 +51,177 @@ You can train fine in colab, but if you get a K80, it's probably best to switch
  The model was converted using [this notebook](https://nbviewer.org/urls/huggingface.co/hivemind/gpt-j-6B-8bit/raw/main/convert-gpt-j.ipynb). It can be adapted to work with other model types. However, please bear in mind that some models replace Linear and Embedding with custom alternatives that require their own BNBWhateverWithAdapters.
+ ### How to use
+
+ ```sh
+ !pip install transformers==4.14.1
+ !pip install bitsandbytes-cuda111==0.26.0
+ !pip install datasets==1.16.1
+ ```
+
+ ```py
+ import transformers
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.cuda.amp import custom_fwd, custom_bwd
+
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+
+ from tqdm.auto import tqdm
+
+ class FrozenBNBLinear(nn.Module):
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both quantized weights and input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 print(name, child)
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+ class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self.attn)
+         convert_to_int8(self.mlp)
+
+
+ class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J
+
+ config = transformers.GPTJConfig.from_pretrained("mrm8488/bertin-gpt-j-6B-ES-8bit")
+ tokenizer = transformers.AutoTokenizer.from_pretrained("mrm8488/bertin-gpt-j-6B-ES-8bit")
+
+ gpt = GPTJForCausalLM.from_pretrained("mrm8488/bertin-gpt-j-6B-ES-8bit", low_cpu_mem_usage=True)
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ gpt.to(device)
+
+ prompt = tokenizer("El sentido de la vida es", return_tensors='pt')
+ prompt = {key: value.to(device) for key, value in prompt.items()}
+ out = gpt.generate(**prompt, max_length=64, do_sample=True)
+ print(tokenizer.decode(out[0]))
+ ```
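
Both `FrozenBNBLinear` and `FrozenBNBEmbedding` in the diff above keep the quantized weights frozen but expose an `adapter` hook whose output is added on top, which is what the "you can train fine in colab" note relies on. The snippet below is a minimal sketch, not part of the commit, of one way to fill those hooks with small trainable low-rank adapters; the name `add_adapters` and the value of `adapter_dim` are illustrative, and it assumes the classes plus the `gpt` and `device` variables defined above.

```py
from torch import nn

def add_adapters(model, adapter_dim: int = 16):
    """Attach small trainable adapters to every frozen 8-bit layer (illustrative sketch)."""
    assert adapter_dim > 0
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            # Low-rank bottleneck: in_features -> adapter_dim -> out_features.
            module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)  # start as a no-op
        elif isinstance(module, FrozenBNBEmbedding):
            module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
            nn.init.zeros_(module.adapter[1].weight)

add_adapters(gpt)
gpt.to(device)  # move the freshly created adapter parameters to the same device
```

The quantized weights are registered as buffers, so they never receive gradients; an optimizer built from `gpt.parameters()` will only update the adapters and the remaining full-precision parameters such as biases and layer norms.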
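
The paragraph above the diff also warns that some models replace Linear and Embedding with custom alternatives. As a purely hypothetical illustration (`MyCustomLinear` is not a real layer from any library, and this sketch reuses the imports and classes from the block above), a custom layer whose forward pass is still a plain affine map can be swapped for `FrozenBNBLinear` directly; only the traversal in `convert_to_int8` needs an extra branch. A layer with a different forward pass would need its own `FrozenBNB*` wrapper along the same lines.

```py
class MyCustomLinear(nn.Module):
    """Hypothetical stand-in for a model-specific replacement of nn.Linear."""
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)  # same math as nn.Linear

def convert_custom_to_int8(model):
    # Same traversal as convert_to_int8 above, with an extra branch for the custom type.
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, (nn.Linear, MyCustomLinear)):
                setattr(module, name, FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                ))
```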