Text Generation · Transformers · PyTorch · bloom · text-generation-inference
mrm8488 committed · Commit 8b55706 · 1 Parent(s): 22c0f1b

Update README.md

Files changed (1): README.md (+222, -1)
README.md CHANGED
@@ -1,3 +1,224 @@
  ---
- license: wtfpl
+ inference: false
+ license: bigscience-bloom-rail-1.0
+ language:
+ - ak
+ - ar
+ - as
+ - bm
+ - bn
+ - ca
+ - en
+ - es
+ - eu
+ - fon
+ - fr
+ - gu
+ - hi
+ - id
+ - ig
+ - ki
+ - kn
+ - lg
+ - ln
+ - ml
+ - mr
+ - ne
+ - nso
+ - ny
+ - or
+ - pa
+ - pt
+ - rn
+ - rw
+ - sn
+ - st
+ - sw
+ - ta
+ - te
+ - tn
+ - ts
+ - tum
+ - tw
+ - ur
+ - vi
+ - wo
+ - xh
+ - yo
+ - zh
+ - zu
+
+ pipeline_tag: text-generation
+
  ---
+ ### Quantized bigscience/bloom 1B3 with 8-bit weights
+
+ Heavily inspired by [Hivemind's GPT-J-6B with 8-bit weights](https://huggingface.co/hivemind/gpt-j-6B-8bit), this is a version of [bigscience/bloom-1b3](https://huggingface.co/bigscience/bloom-1b3) (a ~1.3-billion-parameter language model) that you can run and fine-tune with less memory.
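
For intuition, here is a minimal round-trip sketch of the blockwise 8-bit quantization used throughout this card. It assumes the same `bitsandbytes` API as the code below, where `quantize_blockwise` returns a uint8 tensor together with an `(absmax, code)` state; the shapes are illustrative only:

```python
import torch
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

# A float32 weight matrix stores 4 bytes per value.
weight = torch.randn(1024, 1024)

# Blockwise 8-bit quantization keeps 1 byte per value plus a small per-block
# scale (absmax) and a shared 256-entry codebook (code).
weight_int8, (absmax, code) = quantize_blockwise(weight.flatten())

# Dequantize to recover an approximation of the original weights.
weight_restored = dequantize_blockwise(weight_int8, absmax=absmax, code=code).view_as(weight)

print(weight_int8.dtype)                              # torch.uint8
print((weight - weight_restored).abs().max().item())  # small quantization error
```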
+
+ Here, we also apply [LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685): the frozen 8-bit layers expose an `adapter` slot so that fine-tuning only needs to train a small number of low-rank parameters while the quantized weights stay frozen (see the sketch after the usage example below).
+
+ ### How to fine-tune
+ TBA
+
+ ### How to use
+
+ This model can be used by adapting BLOOM's original implementation. The code below is adapted from [Hivemind's GPT-J 8-bit notebook](https://nbviewer.org/urls/huggingface.co/hivemind/gpt-j-6B-8bit/raw/main/convert-gpt-j.ipynb):
+
+ ```python
+ import transformers
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+ from typing import Tuple
+ from torch.cuda.amp import custom_fwd, custom_bwd
+ from transformers import BloomTokenizerFast
+
+ class FrozenBNBLinear(nn.Module):
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None  # optional trainable (e.g. LoRA-style) adapter added on top of the frozen layer
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both quantized weights and input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     # The zero tensors below are placeholders; the real 8-bit buffers are filled
+     # in when the quantized checkpoint's state dict is loaded.
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 print(name, child)
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+ class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):
+     def __init__(self, config, layer_number=None):
+         super().__init__(config, layer_number)
+
+         convert_to_int8(self.self_attention)
+         convert_to_int8(self.mlp)
+
+
+ class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+ # Monkey-patch the BloomBlock used by transformers so that blocks built inside
+ # from_pretrained are already in their 8-bit form before the weights are loaded.
+ transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock
+
+ model_name = 'mrm8488/bloom-1b3-8bit'
+ model = BloomForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
+ tokenizer = BloomTokenizerFast.from_pretrained(model_name)
+
+ prompt = tokenizer("Given a table named salaries and columns id, created_at, salary, age. Creates a SQL to answer What is the average salary for 22 years old:", return_tensors='pt')
+ out = model.generate(**prompt, min_length=10, do_sample=True)
+ tokenizer.decode(out[0])
+ ```
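
The `adapter` attribute that the classes above leave as `None` is the hook for the LoRA-style fine-tuning mentioned earlier. The official fine-tuning instructions are still TBA; the following is only a minimal sketch, following the pattern of Hivemind's GPT-J-8bit notebook, of how small trainable low-rank adapters could be attached to the frozen 8-bit linear layers. The helper name `add_low_rank_adapters` and the `rank` value are illustrative, not part of this repository:

```python
import torch.nn as nn

def add_low_rank_adapters(model, rank: int = 8):
    """Attach a trainable low-rank adapter to every frozen 8-bit linear layer."""
    for module in model.modules():
        if isinstance(module, FrozenBNBLinear):
            adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            # Zero-init the second projection so the adapter starts as a no-op.
            nn.init.zeros_(adapter[1].weight)
            module.adapter = adapter

add_low_rank_adapters(model)
# The 8-bit base weights are registered as buffers and stay frozen;
# the adapters (plus biases and LayerNorms) are what remains trainable.
```

Because only the low-rank adapters receive gradients, fine-tuning the 8-bit model stays memory-friendly.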