matt-tries-dl committed
Commit 2a75437
1 Parent(s): a611241
Files changed (4)
  1. README.md +2 -0
  2. inf2.pkl +3 -0
  3. wikisql_inf.ipynb +0 -0
  4. wikisql_retrain.ipynb +1611 -0
README.md CHANGED
@@ -16,6 +16,8 @@ https://github.com/tloen/alpaca-lora
  https://huggingface.co/docs/transformers/main/en/model_doc/llama#llama
  https://huggingface.co/docs/transformers/index
  https://github.com/salesforce/WikiSQL
+ https://github.com/huggingface/peft
+
 
 
  https://arxiv.org/pdf/1910.13461.pdf
inf2.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d423a6b369859e281c8abe3c9b9a586c941d9add5e30ef0fb0f128c819de54a
+ size 4305279
wikisql_inf.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
wikisql_retrain.ipynb ADDED
@@ -0,0 +1,1611 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n",
+ " warn(msg)\n",
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
+ "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
+ "The class this function is called from is 'LlamaTokenizer'.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "===================================BUG REPORT===================================\n",
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
+ "================================================================================\n",
+ "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n",
+ "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
+ "CUDA SETUP: Detected CUDA version 113\n",
+ "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n",
+ "PeftConfig(peft_type='LORA', base_model_name_or_path='decapoda-research/llama-7b-hf', task_type='CASUAL_LM', inference_mode=True)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4581e257066e463d9638f5d3a407a70e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/33 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from transformers import LlamaTokenizer, LlamaForCausalLM\n",
+ "from peft import get_peft_model, PeftConfig, PeftModel\n",
+ "\n",
+ "loc = 'sqllama-out3'\n",
+ "\n",
+ "config = PeftConfig.from_pretrained(loc)\n",
+ "print(config)\n",
+ "\n",
+ "\n",
+ "tokenizer = LlamaTokenizer.from_pretrained(\n",
+ " \"decapoda-research/llama-7b-hf\")\n",
+ " \n",
+ "tokenizer.pad_token_id = 0\n",
+ "tokenizer.padding_side = 'left'\n",
+ "\n",
+ "model = LlamaForCausalLM.from_pretrained(\n",
+ " \"decapoda-research/llama-7b-hf\",\n",
+ " load_in_8bit=True,\n",
+ " device_map=\"auto\",\n",
+ " torch_dtype=torch.float16\n",
+ ")\n",
+ "\n",
+ "model = PeftModel.from_pretrained(\n",
+ " model, loc,\n",
+ " torch_dtype=torch.float16,\n",
+ " device_map=\"auto\"\n",
+ " )\n",
+ "\n",
+ "#model.push_to_hub('LlamaSQL-3')\n",
+ "\n",
+ "#model = prepare_model_for_int8_training(model)\n",
+ "\n",
+ "#model = get_peft_model(model,config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "table: 1-20246201-9\n",
+ "columns: Candidate,Office,Home state,Popular vote,States – first place,States – second place,States – third place\n",
+ "Q: How many states-first place are there for the office of Governor?\n",
+ "A: SELECT COUNT States – first place FROM 1-20246201-9 WHERE Office = 'Governor'\n",
+ "END\n",
+ "\n",
+ "\n",
+ "table: 2-17429402-7\n",
+ "columns: School,Years of Participation,OCC Championships,Last OCC Championship,Last Outright OCC Championship\n",
+ "Q: How many times have Central Crossing won the OCC Championship?\n",
+ "A: SELECT SUM OCC Championships FROM 2-17429402-7 WHERE School = 'central crossing'\n",
+ "END\n",
+ "\n",
+ "\n",
+ "table: 1-11677691-10\n",
+ "columns: Player,Position,School,Hometown,College\n",
+ "Q: What town is Muscle Shoals High School located in?\n",
+ "A: SELECT Hometown FROM 1-11677691-10 WHERE School = 'Muscle Shoals High School'\n",
+ "END\n",
+ "\n",
+ "\n",
+ "table: 2-10701914-2\n",
+ "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n",
+ "Q: What is the largest crowd when st kilda was the away team?\n",
+ "A: SELECT MAX Crowd FROM 2-10701914-2 WHERE Away team = 'st kilda'\n",
+ "END\n",
+ "\n",
+ "\n",
+ "table: 2-1122152-1\n",
+ "columns: Driver,Constructor,Laps,Time/Retired,Grid\n",
+ "Q: What is the lap total for the grid under 15 that retired due to transmission?\n",
+ "A: SELECT SUM Laps FROM 2-1122152-1 WHERE Grid < 15 AND Time/Retired = 'transmission'\n",
+ "END\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "import json\n",
+ "\n",
+ "# defined by WikiSQL\n",
+ "\n",
+ "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n",
+ "cond_ops = ['=', '>', '<', 'OP']\n",
+ "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n",
+ "\n",
+ "def fix_repr(d,cols,types,tid):\n",
+ " sel_index=d['sel'] \n",
+ " agg_index=d['agg']\n",
+ " conditions=d['conds']\n",
+ " col = cols[sel_index]\n",
+ " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n",
+ " agg=agg_ops[agg_index],\n",
+ " sel=col,\n",
+ " tid=tid\n",
+ " )\n",
+ " if conditions:\n",
+ " cs = []\n",
+ " for i, o, v in conditions:\n",
+ " #print(i,cols)\n",
+ " nm = cols[i]\n",
+ " op = cond_ops[o]\n",
+ " \n",
+ " if types[i] in ['text']:\n",
+ " val = f\"\\'{v}\\'\"\n",
+ " else:\n",
+ " val = v\n",
+ " cs.append(f'{nm} {op} {val}')\n",
+ " #print(cs)\n",
+ "\n",
+ " rep += ' WHERE ' + ' AND '.join(cs)\n",
+ " \n",
+ " return rep\n",
+ "\n",
+ "tbl_cols = {}\n",
+ "tbl_types = {}\n",
+ "tbl_str = {}\n",
+ "\n",
+ "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n",
+ "\n",
+ "def tbl_def_to_string(id, header, types):\n",
+ " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n",
+ " return s\n",
+ "\n",
+ "with open('data/train.tables.jsonl') as f:\n",
+ " for line in f:\n",
+ " js = json.loads(line)\n",
+ " id = js['id']\n",
+ " hdr = js['header']\n",
+ " ts = js['types']\n",
+ " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n",
+ " tbl_cols[id] = hdr\n",
+ " tbl_types[id] = ts\n",
+ "\n",
+ "q_s = []\n",
+ "a_s = []\n",
+ "\n",
+ "with open('data/train.jsonl') as f:\n",
+ " for line in f:\n",
+ " js = json.loads(line)\n",
+ " id = js['table_id']\n",
+ " s = tbl_str[id]\n",
+ " qst = js['question']\n",
+ " nl = s + '\\nQ: ' + qst + '\\nA: '\n",
+ " q_s.append(nl)\n",
+ "\n",
+ " sql = js['sql']\n",
+ " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
+ " a = a + \"\\nEND\\n\"\n",
+ " a_s.append(a)\n",
+ "\n",
+ "M = len(q_s)\n",
+ "\n",
+ "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n",
+ "\n",
+ "for i in range(5):\n",
+ " j = random.randint(0,M-1)\n",
+ " print()\n",
+ " print(data_txt[j]) \n",
+ " \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "toks = [tokenizer(s) for s in data_txt]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "92\n",
+ " 0\n",
+ "count 56355.000000\n",
+ "mean 101.219519\n",
+ "std 21.740325\n",
+ "min 63.000000\n",
+ "25% 87.500000\n",
+ "50% 97.000000\n",
+ "75% 109.000000\n",
+ "max 461.000000\n",
+ "32084\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "print(len(toks[0].input_ids))\n",
+ "lens = np.array([len(tok.input_ids) for tok in toks])\n",
+ "print(pd.DataFrame(lens).describe())\n",
+ "\n",
+ "z = zip(q_s,lens)\n",
+ "q_red = [a for a,b in z if b < 100]\n",
+ "z = zip(a_s,lens)\n",
+ "a_red = [a for a,b in z if b < 100]\n",
+ "\n",
+ "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n",
+ "print(len(data_red))\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3215cace7aef45b1b040a12f11509a7d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/32084 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import random, datasets\n",
+ "#d = {'prompt': random.sample(data_red, 1000)}\n",
+ "d = {'prompt': data_red}\n",
+ "\n",
+ "data = datasets.Dataset.from_dict(d)\n",
+ "data = data.map(lambda x:\n",
+ " tokenizer(\n",
+ " x['prompt'],\n",
+ " truncation=True,\n",
+ " max_length=100,\n",
+ " padding=\"max_length\"\n",
+ " ))\n",
+ "\n",
+ "data = data.remove_columns('prompt')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#from peft import get_peft_model,PrefixTuningConfig, TaskType, PeftType\n",
+ "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
+ "import torch\n",
+ "import transformers\n",
+ "import datasets\n",
+ "\n",
+ "BATCH = 128\n",
+ "MICRO_BATCH = 4\n",
+ "N_GAS = BATCH//MICRO_BATCH\n",
+ "EPOCHS = 1\n",
+ "LR = 1e-6\n",
+ "\n",
+ "#peft_cfg = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n",
+ "#model = get_peft_model(model,peft_config)\n",
+ "#model = model.to(torch.device('cuda'))\n",
+ "\n",
+ "targs = transformers.TrainingArguments(\n",
+ " per_device_train_batch_size=MICRO_BATCH,\n",
+ " gradient_accumulation_steps=N_GAS,\n",
+ " warmup_steps=20,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " learning_rate=LR,\n",
+ " fp16=True,\n",
+ " logging_steps=1,\n",
+ " output_dir='sqllama-out3-rt',\n",
+ " save_total_limit=3,\n",
+ " remove_unused_columns=False,\n",
+ " \n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='250' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [250/250 3:05:44, Epoch 0/1]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Step</th>\n",
+ " <th>Training Loss</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ "      <tr><td>1</td><td>1.121900</td></tr>\n",
+ "      [ rows for steps 2–249: per-step training loss stays in the range of roughly 1.05–1.17 for the whole run ]\n",
+ "      <tr><td>250</td><td>1.119200</td></tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "trainer = transformers.Trainer(\n",
+ " model = model,\n",
+ " train_dataset = data,\n",
+ " args = targs,\n",
+ " #data_collator=transformers.DefaultDataCollator\n",
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
+ ")\n",
+ "trainer.train(resume_from_checkpoint=False)\n",
+ "model.save_pretrained('sqllama-out3')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "from model\n",
+ "<unk>table: 2-10806592-14\n",
+ "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n",
+ "Q: What was the away score when the home team was Melbourne?\n",
+ "A: SELECT Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n",
+ "FROM 2-1080\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 2-17978030-6\n",
+ "columns: Date,Time,Score,Set 1,Set 2,Set 3,Total\n",
+ "Q: What is the score when the set 3 is 26–28?\n",
+ "A: SELECT Score FROM 2-17978030-6 WHERE Set 3 = 26–28\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 2-12487101-6\n",
+ "columns: Position,Number,Name,Class,Injury (Status),Last Update\n",
+ "Q: for the position of rb what is the name?\n",
+ "A: SELECT Name FROM 2-12487101-6 WHERE Position = 'rb'\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "Comment: I'm not sure what you mean by \"the name of\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Name FROM 2-12487101-6 WHERE Position = 'rb'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 2-16780011-2\n",
+ "columns: Week,Date,Opponent,Result,TV Time,Attendance\n",
+ "Q: Which week had an attendance of 55,158?\n",
+ "A: SELECT Attendance FROM 2-16780011-2 WHERE Week = '2019-01-06'\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "\n",
+ "\n",
+ "expected answer\n",
+ "SELECT SUM Week FROM 2-16780011-2 WHERE Attendance = '55,158'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 1-1671401-1\n",
+ "columns: Year,Starts,Wins,Top 5,Top 10,Poles,Avg. Start,Avg. Finish,Winnings,Position,Team(s)\n",
+ "Q: Name the least top 5\n",
+ "A: SELECT MIN Top 5 FROM 1-1671401-1\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "\n",
+ "expected answer\n",
+ "SELECT MIN Top 5 FROM 1-1671401-1\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 2-17540875-4\n",
+ "columns: Tie no,Home team,Score,Away team,Date\n",
+ "Q: What team was the home team when Manchester City was the away team?\n",
+ "A: SELECT Tie no,Home team FROM 2-17540875-4 WHERE Away team = 'Manchester City'\n",
+ "END\n",
+ "\n",
+ "Comment: @user378326\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Home team FROM 2-17540875-4 WHERE Away team = 'manchester city'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 2-15524351-11\n",
+ "columns: Date,Opponent,Location,Result,Attendance\n",
+ "Q: What team was the opponent when the result was l 0-27?\n",
+ "A: SELECT Opponent FROM 2-15524351-11 WHERE Result = 'l 0-27'\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "Comment\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Opponent FROM 2-15524351-11 WHERE Result = 'l 0-27'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 1-12962773-1\n",
+ "columns: No,Player,Height,Position,Year born,Current Club\n",
+ "Q: what's current club with height being 2.09\n",
+ "A: SELECT Current Club\n",
+ "FROM 1-12962773-1\n",
+ "WHERE Height = 2.09\n",
+ "END\n",
+ "\\end{code}\n",
+ "\n",
+ "Comment: I'm\n",
+ "\n",
+ "expected answer\n",
+ "SELECT Current Club FROM 1-12962773-1 WHERE Height = '2.09'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 1-12441518-1\n",
+ "columns: Character,Portrayed by,Main cast seasons,Recurring cast seasons,# of episodes\n",
+ "Q: How many people named Nick Lucas are on the show?\n",
+ "A: SELECT COUNT(Character) FROM 1-12441518-1 WHERE Character = 'Nick Lucas'\n",
+ "END\n",
+ "\n",
+ "Comment: @user3783264,\n",
+ "\n",
+ "expected answer\n",
+ "SELECT COUNT Portrayed by FROM 1-12441518-1 WHERE Character = 'Nick Lucas'\n",
+ "END\n",
+ "\n",
+ "from model\n",
+ "<unk>table: 1-17176509-4\n",
+ "columns: Position,Driver / Passenger,Equipment,Bike No,Points\n",
+ "Q: Name the driver/passenger for 30\n",
+ "A: SELECT Driver / Passenger\n",
+ "FROM 1-17176509-4\n",
+ "WHERE 30 = 'Driver / Passenger'\n",
+ "END\n",
+ "\n",
+ "Comment: I'm not sure what you mean by\n",
+ "\n",
+ "expected answer\n",
+ "SELECT COUNT Driver / Passenger FROM 1-17176509-4 WHERE Position = 30\n",
+ "END\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "def get_query(q):\n",
+ " \n",
+ " toks = tokenizer(q , return_tensors='pt')\n",
+ " ctoks = toks.input_ids.to('cuda')\n",
+ " gen = model.generate(ctoks, max_length=100)\n",
+ " return tokenizer.decode(gen[0])\n",
+ "\n",
+ "M = len(q_red)\n",
+ "\n",
+ "for _ in range(10):\n",
+ " j = random.randint(0,M-1)\n",
+ " qs = q_red[j]\n",
+ " a = a_red[j]\n",
+ "\n",
+ " ma = get_query(qs)\n",
+ "\n",
+ " #print(qs)\n",
+ " print('from model')\n",
+ " print(ma)\n",
+ " print()\n",
+ " print('expected answer')\n",
+ " print(a)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }