justheuristic commited on
Commit
08ba7c1
·
1 Parent(s): 0784a51

Create check_perplexity.ipynb

Browse files
Files changed (1) hide show
  1. check_perplexity.ipynb +691 -0
check_perplexity.ipynb ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### Original GPT-J perlexity"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import torch\n",
17
+ "import torch.nn as nn\n",
18
+ "import torch.nn.functional as F\n",
19
+ "\n",
20
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
21
+ "import transformers\n",
22
+ "from tqdm.auto import tqdm\n",
23
+ "\n",
24
+ "\n",
25
+ "\n",
26
+ "model_name = \"EleutherAI/gpt-j-6B\"\n",
27
+ "gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)\n",
28
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 11,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "device = 'cuda' if torch.cuda.is_available else 'cpu'\n",
38
+ "gpt.to(device).train(False);"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 4,
44
+ "metadata": {},
45
+ "outputs": [
46
+ {
47
+ "name": "stderr",
48
+ "output_type": "stream",
49
+ "text": [
50
+ "Reusing dataset wikitext (/home/jheuristic/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)\n"
51
+ ]
52
+ },
53
+ {
54
+ "data": {
55
+ "application/vnd.jupyter.widget-view+json": {
56
+ "model_id": "47f0459174da4ee2bf064c9ae81fdecd",
57
+ "version_major": 2,
58
+ "version_minor": 0
59
+ },
60
+ "text/plain": [
61
+ " 0%| | 0/3 [00:00<?, ?it/s]"
62
+ ]
63
+ },
64
+ "metadata": {},
65
+ "output_type": "display_data"
66
+ }
67
+ ],
68
+ "source": [
69
+ "from datasets import load_dataset\n",
70
+ "data = load_dataset('wikitext', 'wikitext-2-v1')['test']"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 62,
76
+ "metadata": {},
77
+ "outputs": [
78
+ {
79
+ "data": {
80
+ "application/vnd.jupyter.widget-view+json": {
81
+ "model_id": "26cca02205624aafa740e55542ca2e6c",
82
+ "version_major": 2,
83
+ "version_minor": 0
84
+ },
85
+ "text/plain": [
86
+ " 0%| | 0/4358 [00:00<?, ?it/s]"
87
+ ]
88
+ },
89
+ "metadata": {},
90
+ "output_type": "display_data"
91
+ }
92
+ ],
93
+ "source": [
94
+ "\n",
95
+ "numerator, denominator = 0, 0\n",
96
+ "collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
97
+ "loader = torch.utils.data.DataLoader(data, batch_size=1, num_workers=0, shuffle=False)\n",
98
+ "\n",
99
+ "\n",
100
+ "with torch.no_grad(), torch.cuda.amp.autocast(), tqdm(loader) as progressbar:\n",
101
+ " for i, row in enumerate(progressbar):\n",
102
+ " if max(map(len, row['text'])) <= 1:\n",
103
+ " continue\n",
104
+ " batch = tokenizer(**row, truncation=False, return_tensors='pt')\n",
105
+ " batch = {k: v.cuda() for k, v in batch.items()}\n",
106
+ "\n",
107
+ " out = gpt.forward(**batch,)\n",
108
+ "\n",
109
+ " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
110
+ " reduction='none')\n",
111
+ "\n",
112
+ " numerator += loss.sum().item()\n",
113
+ " denominator += len(loss)\n",
114
+ " progressbar.desc = f\"{numerator/denominator:.3f}\""
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 63,
120
+ "metadata": {},
121
+ "outputs": [
122
+ {
123
+ "data": {
124
+ "text/plain": [
125
+ "18.435175441788164"
126
+ ]
127
+ },
128
+ "execution_count": 63,
129
+ "metadata": {},
130
+ "output_type": "execute_result"
131
+ }
132
+ ],
133
+ "source": [
134
+ "# test perplexity\n",
135
+ "import math\n",
136
+ "math.exp(numerator/denominator)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "metadata": {},
142
+ "source": [
143
+ "### Quantized GPT-J Perplexity"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 64,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "\n",
153
+ "import torch\n",
154
+ "import torch.nn as nn\n",
155
+ "from torch.cuda.amp import custom_fwd, custom_bwd\n",
156
+ " \n",
157
+ "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
158
+ "import transformers\n",
159
+ "\n",
160
+ "\n",
161
+ "class DequantizeAndLinear(torch.autograd.Function):\n",
162
+ " \n",
163
+ " @staticmethod\n",
164
+ " @custom_fwd\n",
165
+ " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
166
+ " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
167
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
168
+ " ctx.save_for_backward(input, weights_quantized, absmax, code)\n",
169
+ " ctx._has_bias = bias is not None\n",
170
+ " return F.linear(input, weights_deq, bias)\n",
171
+ " \n",
172
+ " @staticmethod\n",
173
+ " @custom_bwd\n",
174
+ " def backward(ctx, grad_output: torch.Tensor):\n",
175
+ " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
176
+ " input, weights_quantized, absmax, code = ctx.saved_tensors\n",
177
+ " # grad_output: [*batch, out_features]\n",
178
+ " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
179
+ " grad_input = grad_output @ weights_deq\n",
180
+ " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n",
181
+ " return grad_input, None, None, None, grad_bias\n",
182
+ "\n",
183
+ "\n",
184
+ "class FrozenBNBLinear(nn.Module):\n",
185
+ " def __init__(self, weight, absmax, code, bias=None):\n",
186
+ " assert isinstance(bias, nn.Parameter) or bias is None\n",
187
+ " super().__init__()\n",
188
+ " self.out_features, self.in_features = weight.shape\n",
189
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
190
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
191
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
192
+ " self.bias = bias\n",
193
+ " \n",
194
+ " def forward(self, input):\n",
195
+ " return DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
196
+ " \n",
197
+ " @classmethod\n",
198
+ " def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n",
199
+ " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
200
+ " return cls(weights_int8, *state, linear.bias)\n",
201
+ " \n",
202
+ " def __repr__(self):\n",
203
+ " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
204
+ " \n",
205
+ " \n",
206
+ "class FrozenBNBEmbedding(nn.Module):\n",
207
+ " def __init__(self, weight, absmax, code):\n",
208
+ " super().__init__()\n",
209
+ " self.num_embeddings, self.embedding_dim = weight.shape\n",
210
+ " self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
211
+ " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
212
+ " self.register_buffer(\"code\", code.requires_grad_(False))\n",
213
+ " \n",
214
+ " def forward(self, x, **kwargs):\n",
215
+ " with torch.no_grad():\n",
216
+ " # note: both quantuized weights and input indices are *not* differentiable\n",
217
+ " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
218
+ " return F.embedding(x, weight_deq, **kwargs)\n",
219
+ " \n",
220
+ " @classmethod\n",
221
+ " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n",
222
+ " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
223
+ " return cls(weights_int8, *state)\n",
224
+ " \n",
225
+ " def __repr__(self):\n",
226
+ " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n",
227
+ " \n",
228
+ " \n",
229
+ "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
230
+ " assert chunk_size % 4096 == 0\n",
231
+ " code = None\n",
232
+ " chunks = []\n",
233
+ " absmaxes = []\n",
234
+ " flat_tensor = matrix.view(-1)\n",
235
+ " for i in range((matrix.numel() - 1) // chunk_size + 1):\n",
236
+ " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n",
237
+ " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
238
+ " chunks.append(quantized_chunk)\n",
239
+ " absmaxes.append(absmax_chunk)\n",
240
+ " \n",
241
+ " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
242
+ " absmax = torch.cat(absmaxes)\n",
243
+ " return matrix_i8, (absmax, code)\n",
244
+ "\n",
245
+ "\n",
246
+ "def dummify(model, adapter_dim: int = 0):\n",
247
+ " for module in list(model.modules()):\n",
248
+ " for name, child in module.named_children():\n",
249
+ " if isinstance(child, nn.Linear):\n",
250
+ " print(name, child)\n",
251
+ " setattr(\n",
252
+ " module,\n",
253
+ " name,\n",
254
+ " FrozenBNBLinear(\n",
255
+ " weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n",
256
+ " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
257
+ " code=torch.zeros(256),\n",
258
+ " bias=child.bias,\n",
259
+ " ),\n",
260
+ " )\n",
261
+ " elif isinstance(child, nn.Embedding):\n",
262
+ " setattr(\n",
263
+ " module,\n",
264
+ " name,\n",
265
+ " FrozenBNBEmbedding(\n",
266
+ " weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n",
267
+ " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n",
268
+ " code=torch.zeros(256),\n",
269
+ " )\n",
270
+ " ),\n",
271
+ "\n",
272
+ "\n",
273
+ "def bnbfy_(model, adapter_dim: int = 0):\n",
274
+ " for module in list(model.modules()):\n",
275
+ " for name, child in module.named_children():\n",
276
+ " if isinstance(child, nn.Linear):\n",
277
+ " print(name, child)\n",
278
+ " setattr(module, name, FrozenBNBLinear.from_linear(child))\n",
279
+ " \n",
280
+ " elif isinstance(child, nn.Embedding):\n",
281
+ " print(name, child)\n",
282
+ " setattr(module, name, FrozenBNBEmbedding.from_embedding(child))"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 66,
288
+ "metadata": {},
289
+ "outputs": [],
290
+ "source": [
291
+ "class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):\n",
292
+ " def __init__(self, config):\n",
293
+ " print(\"MONKEYPATCH BLOCK\")\n",
294
+ " super().__init__(config)\n",
295
+ "\n",
296
+ " dummify(self.attn)\n",
297
+ " dummify(self.mlp)\n",
298
+ "\n",
299
+ "transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock\n",
300
+ "\n",
301
+ "\n",
302
+ "class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):\n",
303
+ " def __init__(self, config):\n",
304
+ " super().__init__(config)\n",
305
+ " dummify(self)\n",
306
+ "class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):\n",
307
+ " def __init__(self, config):\n",
308
+ " super().__init__(config)\n",
309
+ " dummify(self)\n"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 67,
315
+ "metadata": {},
316
+ "outputs": [
317
+ {
318
+ "data": {
319
+ "application/vnd.jupyter.widget-view+json": {
320
+ "model_id": "1c98b9ebbf8d44d8b0bc422d4bfce21f",
321
+ "version_major": 2,
322
+ "version_minor": 0
323
+ },
324
+ "text/plain": [
325
+ "Downloading: 0%| | 0.00/0.98k [00:00<?, ?B/s]"
326
+ ]
327
+ },
328
+ "metadata": {},
329
+ "output_type": "display_data"
330
+ },
331
+ {
332
+ "data": {
333
+ "application/vnd.jupyter.widget-view+json": {
334
+ "model_id": "04bc6b612ff146308ec0b63fc15640f8",
335
+ "version_major": 2,
336
+ "version_minor": 0
337
+ },
338
+ "text/plain": [
339
+ "Downloading: 0%| | 0.00/5.75G [00:00<?, ?B/s]"
340
+ ]
341
+ },
342
+ "metadata": {},
343
+ "output_type": "display_data"
344
+ },
345
+ {
346
+ "name": "stdout",
347
+ "output_type": "stream",
348
+ "text": [
349
+ "MONKEYPATCH BLOCK\n",
350
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
351
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
352
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
353
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
354
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
355
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
356
+ "MONKEYPATCH BLOCK\n",
357
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
358
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
359
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
360
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
361
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
362
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
363
+ "MONKEYPATCH BLOCK\n",
364
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
365
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
366
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
367
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
368
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
369
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
370
+ "MONKEYPATCH BLOCK\n",
371
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
372
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
373
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
374
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
375
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
376
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
377
+ "MONKEYPATCH BLOCK\n",
378
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
379
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
380
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
381
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
382
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
383
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
384
+ "MONKEYPATCH BLOCK\n",
385
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
386
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
387
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
388
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
389
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
390
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
391
+ "MONKEYPATCH BLOCK\n",
392
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
393
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
394
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
395
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
396
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
397
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
398
+ "MONKEYPATCH BLOCK\n",
399
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
400
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
401
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
402
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
403
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
404
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
405
+ "MONKEYPATCH BLOCK\n",
406
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
407
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
408
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
409
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
410
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
411
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
412
+ "MONKEYPATCH BLOCK\n",
413
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
414
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
415
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
416
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
417
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
418
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
419
+ "MONKEYPATCH BLOCK\n",
420
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
421
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
422
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
423
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
424
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
425
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
426
+ "MONKEYPATCH BLOCK\n",
427
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
428
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
429
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
430
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
431
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
432
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
433
+ "MONKEYPATCH BLOCK\n",
434
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
435
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
436
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
437
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
438
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
439
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
440
+ "MONKEYPATCH BLOCK\n",
441
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
442
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
443
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
444
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
445
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
446
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
447
+ "MONKEYPATCH BLOCK\n",
448
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
449
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
450
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
451
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
452
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
453
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
454
+ "MONKEYPATCH BLOCK\n",
455
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
456
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
457
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
458
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
459
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
460
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
461
+ "MONKEYPATCH BLOCK\n",
462
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
463
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
464
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
465
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
466
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
467
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
468
+ "MONKEYPATCH BLOCK\n",
469
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
470
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
471
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
472
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
473
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
474
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
475
+ "MONKEYPATCH BLOCK\n",
476
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
477
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
478
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
479
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
480
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
481
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
482
+ "MONKEYPATCH BLOCK\n",
483
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
484
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
485
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
486
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
487
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
488
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
489
+ "MONKEYPATCH BLOCK\n",
490
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
491
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
492
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
493
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
494
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
495
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
496
+ "MONKEYPATCH BLOCK\n"
497
+ ]
498
+ },
499
+ {
500
+ "name": "stdout",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
504
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
505
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
506
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
507
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
508
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
509
+ "MONKEYPATCH BLOCK\n",
510
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
511
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
512
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
513
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
514
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
515
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
516
+ "MONKEYPATCH BLOCK\n",
517
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
518
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
519
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
520
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
521
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
522
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
523
+ "MONKEYPATCH BLOCK\n",
524
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
525
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
526
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
527
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
528
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
529
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
530
+ "MONKEYPATCH BLOCK\n",
531
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
532
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
533
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
534
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
535
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
536
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
537
+ "MONKEYPATCH BLOCK\n",
538
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
539
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
540
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
541
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
542
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
543
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
544
+ "MONKEYPATCH BLOCK\n",
545
+ "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
546
+ "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
547
+ "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
548
+ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n",
549
+ "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n",
550
+ "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n",
551
+ "lm_head Linear(in_features=4096, out_features=50400, bias=True)\n"
552
+ ]
553
+ }
554
+ ],
555
+ "source": [
556
+ "config = transformers.GPTJConfig.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
557
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n",
558
+ "gpt = GPTJForCausalLM.from_pretrained(\"hivemind/gpt-j-6B-8bit\", low_cpu_mem_usage=True)"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": 68,
564
+ "metadata": {},
565
+ "outputs": [],
566
+ "source": [
567
+ "device = 'cuda' if torch.cuda.is_available else 'cpu'\n",
568
+ "gpt.to(device).train(False);"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "execution_count": 69,
574
+ "metadata": {},
575
+ "outputs": [
576
+ {
577
+ "name": "stderr",
578
+ "output_type": "stream",
579
+ "text": [
580
+ "Reusing dataset wikitext (/home/jheuristic/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)\n"
581
+ ]
582
+ },
583
+ {
584
+ "data": {
585
+ "application/vnd.jupyter.widget-view+json": {
586
+ "model_id": "bfbf0e20ed194d679d2f877085f679cb",
587
+ "version_major": 2,
588
+ "version_minor": 0
589
+ },
590
+ "text/plain": [
591
+ " 0%| | 0/3 [00:00<?, ?it/s]"
592
+ ]
593
+ },
594
+ "metadata": {},
595
+ "output_type": "display_data"
596
+ }
597
+ ],
598
+ "source": [
599
+ "from datasets import load_dataset\n",
600
+ "data = load_dataset('wikitext', 'wikitext-2-v1')['test']"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": 70,
606
+ "metadata": {},
607
+ "outputs": [
608
+ {
609
+ "data": {
610
+ "application/vnd.jupyter.widget-view+json": {
611
+ "model_id": "53d7e76934de4a1498306d49e4f41ad2",
612
+ "version_major": 2,
613
+ "version_minor": 0
614
+ },
615
+ "text/plain": [
616
+ " 0%| | 0/4358 [00:00<?, ?it/s]"
617
+ ]
618
+ },
619
+ "metadata": {},
620
+ "output_type": "display_data"
621
+ }
622
+ ],
623
+ "source": [
624
+ "\n",
625
+ "numerator, denominator = 0, 0\n",
626
+ "collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
627
+ "loader = torch.utils.data.DataLoader(data, batch_size=1, num_workers=0, shuffle=False)\n",
628
+ "\n",
629
+ "\n",
630
+ "with torch.no_grad(), torch.cuda.amp.autocast(), tqdm(loader) as progressbar:\n",
631
+ " for i, row in enumerate(progressbar):\n",
632
+ " if max(map(len, row['text'])) <= 1:\n",
633
+ " continue\n",
634
+ " batch = tokenizer(**row, truncation=False, return_tensors='pt')\n",
635
+ " batch = {k: v.cuda() for k, v in batch.items()}\n",
636
+ "\n",
637
+ " out = gpt.forward(**batch,)\n",
638
+ "\n",
639
+ " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n",
640
+ " reduction='none')\n",
641
+ "\n",
642
+ " numerator += loss.sum().item()\n",
643
+ " denominator += len(loss)\n",
644
+ " progressbar.desc = f\"{numerator/denominator:.3f}\""
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": 71,
650
+ "metadata": {},
651
+ "outputs": [
652
+ {
653
+ "data": {
654
+ "text/plain": [
655
+ "18.427138288946292"
656
+ ]
657
+ },
658
+ "execution_count": 71,
659
+ "metadata": {},
660
+ "output_type": "execute_result"
661
+ }
662
+ ],
663
+ "source": [
664
+ "# test perplexity\n",
665
+ "import math\n",
666
+ "math.exp(numerator/denominator)"
667
+ ]
668
+ }
669
+ ],
670
+ "metadata": {
671
+ "kernelspec": {
672
+ "display_name": "py38",
673
+ "language": "python",
674
+ "name": "py38"
675
+ },
676
+ "language_info": {
677
+ "codemirror_mode": {
678
+ "name": "ipython",
679
+ "version": 3
680
+ },
681
+ "file_extension": ".py",
682
+ "mimetype": "text/x-python",
683
+ "name": "python",
684
+ "nbconvert_exporter": "python",
685
+ "pygments_lexer": "ipython3",
686
+ "version": "3.8.1"
687
+ }
688
+ },
689
+ "nbformat": 4,
690
+ "nbformat_minor": 2
691
+ }