matt-tries-dl commited on
Commit
ce91638
1 Parent(s): 7dd7ab4

second pass

Browse files
Files changed (3) hide show
  1. res2.txt +340 -0
  2. sqllama-out2/adapter_model.bin +1 -1
  3. wikisql.ipynb +1254 -125
res2.txt ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Output exceeds the size limit. Open the full output data in a text editor
2
+
3
+ table: 2-1137692-1
4
+ columns: Entrant,Constructor,Chassis,Engine †,Tyre,Driver,Rounds
5
+ Q: What were the rounds on the Engine † of the Ferrari 048?
6
+ A: SELECT Rounds FROM 2-1137692-1 WHERE Engine † = 'ferrari 048'
7
+ END
8
+
9
+
10
+ table: 1-21530474-1
11
+ columns: Chassis code,Model no.,Production years,Drivetrain,Transmission,Engine type,Engine code,Region(s)
12
+ Q: Name the drivetrain for 1ur-fse for usf41
13
+ A: SELECT Drivetrain FROM 1-21530474-1 WHERE Engine code = '1UR-FSE' AND Chassis code = 'USF41'
14
+ END
15
+
16
+
17
+ table: 2-14155087-1
18
+ columns: Callsign,Area served,Frequency,Band,On-air ID,Purpose
19
+ Q: What is the Callsign with an Area of tamworth and frequency of 0 88.9?
20
+ A: SELECT Callsign FROM 2-14155087-1 WHERE Area served = 'tamworth' AND Frequency = '0 88.9'
21
+ END
22
+
23
+
24
+ table: 2-17580726-2
25
+ columns: Date,Opponent,Venue,Score,Attendance,Scorers
26
+ Q: What is the number of people in attendance when Tonbridge Angels is the opponent?
27
+ ...
28
+ Q: What were the match points when Bordeaux-Bègles was eliminated from competition?
29
+ A: SELECT Match points FROM 1-27986200-3 WHERE Eliminated from competition = 'Bordeaux-Bègles'
30
+ END
31
+
32
+ /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...
33
+ warn(msg)
34
+ The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
35
+ The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
36
+ The class this function is called from is 'LlamaTokenizer'.
37
+
38
+ ===================================BUG REPORT===================================
39
+ Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
40
+ ================================================================================
41
+ CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
42
+ CUDA SETUP: Highest compute capability among GPUs detected: 7.5
43
+ CUDA SETUP: Detected CUDA version 113
44
+ CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...
45
+ True
46
+ 92
47
+ 0
48
+ count 56355.000000
49
+ mean 101.219519
50
+ std 21.740325
51
+ min 63.000000
52
+ 25% 87.500000
53
+ 50% 97.000000
54
+ 75% 109.000000
55
+ max 461.000000
56
+ 32084
57
+ [250/250 3:49:26, Epoch 0/1]
58
+ Step Training Loss
59
+ 1 2.748800
60
+ 2 2.699100
61
+ 3 2.670200
62
+ 4 2.600500
63
+ 5 2.560100
64
+ 6 2.556800
65
+ 7 2.498100
66
+ 8 2.515400
67
+ 9 2.436100
68
+ 10 2.411700
69
+ 11 2.346400
70
+ 12 2.276300
71
+ 13 2.238000
72
+ 14 2.189100
73
+ 15 2.109200
74
+ 16 2.058000
75
+ 17 1.983900
76
+ 18 1.928600
77
+ 19 1.824100
78
+ 20 1.794700
79
+ 21 1.681200
80
+ 22 1.598900
81
+ 23 1.562000
82
+ 24 1.527200
83
+ 25 1.518700
84
+ 26 1.493100
85
+ 27 1.500500
86
+ 28 1.464000
87
+ 29 1.386900
88
+ 30 1.373400
89
+ 31 1.362200
90
+ 32 1.360800
91
+ 33 1.321000
92
+ 34 1.310500
93
+ 35 1.302600
94
+ 36 1.256100
95
+ 37 1.252500
96
+ 38 1.202300
97
+ 39 1.249100
98
+ 40 1.188600
99
+ 41 1.203200
100
+ 42 1.150000
101
+ 43 1.182000
102
+ 44 1.192300
103
+ 45 1.133100
104
+ 46 1.119600
105
+ 47 1.097000
106
+ 48 1.142100
107
+ 49 1.117200
108
+ 50 1.129200
109
+ 51 1.087300
110
+ 52 1.098700
111
+ 53 1.135400
112
+ 54 1.071700
113
+ 55 1.087300
114
+ 56 1.051400
115
+ 57 1.068300
116
+ 58 1.092500
117
+ 59 1.068600
118
+ 60 1.072800
119
+ 61 1.074000
120
+ 62 1.060400
121
+ 63 1.065800
122
+ 64 1.075900
123
+ 65 1.059500
124
+ 66 1.039600
125
+ 67 1.051400
126
+ 68 1.049500
127
+ 69 1.023800
128
+ 70 1.071900
129
+ 71 1.051000
130
+ 72 1.034700
131
+ 73 1.041600
132
+ 74 1.030900
133
+ 75 1.010800
134
+ 76 1.019800
135
+ 77 1.005000
136
+ 78 1.043800
137
+ 79 1.009200
138
+ 80 1.017100
139
+ 81 1.044600
140
+ 82 1.022600
141
+ 83 1.011400
142
+ 84 0.996600
143
+ 85 1.029900
144
+ 86 0.988200
145
+ 87 1.005600
146
+ 88 0.986600
147
+ 89 1.025300
148
+ 90 1.012500
149
+ 91 0.988100
150
+ 92 1.001800
151
+ 93 0.987100
152
+ 94 1.017600
153
+ 95 0.998500
154
+ 96 0.966600
155
+ 97 0.983700
156
+ 98 0.961800
157
+ 99 0.969000
158
+ 100 0.989200
159
+ 101 0.956400
160
+ 102 0.976000
161
+ 103 1.000100
162
+ 104 1.001500
163
+ 105 0.995900
164
+ 106 0.989700
165
+ 107 0.965700
166
+ 108 0.968400
167
+ 109 1.019600
168
+ 110 1.000100
169
+ 111 0.978500
170
+ 112 0.978900
171
+ 113 0.952600
172
+ 114 0.975400
173
+ 115 0.989400
174
+ 116 0.968500
175
+ 117 0.960100
176
+ 118 0.979100
177
+ 119 0.955100
178
+ 120 0.934800
179
+ 121 0.943600
180
+ 122 0.976700
181
+ 123 0.998700
182
+ 124 0.930500
183
+ 125 0.953500
184
+ 126 0.978000
185
+ 127 0.967300
186
+ 128 0.929400
187
+ 129 0.963100
188
+ 130 0.961500
189
+ 131 0.978500
190
+ 132 0.937200
191
+ 133 0.953400
192
+ 134 0.962000
193
+ 135 0.950700
194
+ 136 0.925100
195
+ 137 0.958800
196
+ 138 0.926200
197
+ 139 0.930600
198
+ 140 0.968900
199
+ 141 0.970400
200
+ 142 0.927100
201
+ 143 0.911800
202
+ 144 0.953200
203
+ 145 0.907100
204
+ 146 0.935900
205
+ 147 0.970600
206
+ 148 0.920400
207
+ 149 0.930200
208
+ 150 0.926700
209
+ 151 0.913400
210
+ 152 0.926800
211
+ 153 0.967200
212
+ 154 0.939500
213
+ 155 0.910600
214
+ 156 0.926400
215
+ 157 0.935400
216
+ 158 0.967700
217
+ 159 0.899000
218
+ 160 0.916600
219
+ 161 0.961600
220
+ 162 0.898200
221
+ 163 0.944600
222
+ 164 0.935700
223
+ 165 0.922500
224
+ 166 0.897600
225
+ 167 0.968600
226
+ 168 0.927400
227
+ 169 0.910900
228
+ 170 0.904700
229
+ 171 0.899800
230
+ 172 0.896400
231
+ 173 0.862100
232
+ 174 0.909100
233
+ 175 0.903200
234
+ 176 0.958600
235
+ 177 0.902500
236
+ 178 0.894900
237
+ 179 0.937900
238
+ 180 0.900700
239
+ 181 0.922300
240
+ 182 0.939300
241
+ 183 0.932600
242
+ 184 0.913300
243
+ 185 0.941700
244
+ 186 0.886300
245
+ 187 0.918000
246
+ 188 0.884000
247
+ 189 0.947400
248
+ 190 0.894500
249
+ 191 0.929300
250
+ 192 0.877300
251
+ 193 0.894300
252
+ 194 0.867800
253
+ 195 0.913500
254
+ 196 0.908100
255
+ 197 0.931200
256
+ 198 0.911000
257
+ 199 0.941800
258
+ 200 0.913000
259
+ 201 0.921800
260
+ 202 0.921700
261
+ 203 0.914500
262
+ 204 0.910500
263
+ 205 0.906600
264
+ 206 0.915100
265
+ 207 0.881600
266
+ 208 0.884700
267
+ 209 0.902900
268
+ 210 0.882600
269
+ 211 0.891000
270
+ 212 0.914400
271
+ 213 0.930400
272
+ 214 0.891100
273
+ 215 0.859300
274
+ 216 0.891800
275
+ 217 0.873000
276
+ 218 0.925900
277
+ 219 0.905700
278
+ 220 0.921200
279
+ 221 0.890200
280
+ 222 0.915800
281
+ 223 0.887300
282
+ 224 0.898300
283
+ 225 0.865600
284
+ 226 0.873900
285
+ 227 0.904800
286
+ 228 0.917900
287
+ 229 0.923400
288
+ 230 0.939700
289
+ 231 0.913400
290
+ 232 0.873100
291
+ 233 0.896700
292
+ 234 0.892100
293
+ 235 0.902100
294
+ 236 0.927200
295
+ 237 0.912900
296
+ 238 0.872900
297
+ 239 0.904700
298
+ 240 0.879600
299
+ 241 0.879800
300
+ 242 0.908800
301
+ 243 0.909800
302
+ 244 0.838400
303
+ 245 0.889200
304
+ 246 0.912900
305
+ 247 0.879700
306
+ 248 0.910700
307
+ 249 0.845400
308
+ 250 0.882200
309
+ /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
310
+ warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
311
+ Output exceeds the size limit. Open the full output data in a text editor
312
+ from model
313
+ <unk>table: 1-12028543-3
314
+ columns: Season,Cup FinalDate,WinningTeam,Score,LosingTeam,Location,Cup Final Attendance
315
+ Q: Who was the winning team in the 1989 season?
316
+ A: SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'
317
+ END
318
+ END
319
+ END
320
+ END
321
+
322
+ expected answer
323
+ SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'
324
+ END
325
+
326
+ from model
327
+ <unk>table: 2-18096431-5
328
+ columns: Place,Player,Country,Score,To par
329
+ Q: What is To par, when Country is "United States", and when Player is "Mark Brooks"?
330
+ A: 18-1
331
+ END
332
+
333
+
334
+ expected answer
335
+ SELECT To par FROM 2-18096431-5 WHERE Country = 'united states' AND Player = 'mark brooks'
336
+ END
337
+ ...
338
+ expected answer
339
+ SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'
340
+ END
sqllama-out2/adapter_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ee15525f45ab11e3e7ba334c0639b7263ea25ae0d42aa22f801022020ffc493
3
  size 8434381
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:423e1095c473c661c7ce62b23e3c06bb80780c572177c2ab77ef1451117cf83d
3
  size 8434381
wikisql.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 11,
6
  "metadata": {},
7
  "outputs": [
8
  {
@@ -11,7 +11,7 @@
11
  "True"
12
  ]
13
  },
14
- "execution_count": 11,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
@@ -24,7 +24,7 @@
24
  },
25
  {
26
  "cell_type": "code",
27
- "execution_count": null,
28
  "metadata": {},
29
  "outputs": [
30
  {
@@ -55,7 +55,7 @@
55
  {
56
  "data": {
57
  "application/vnd.jupyter.widget-view+json": {
58
- "model_id": "a9428ee09f334655b6b261d478cbd3d0",
59
  "version_major": 2,
60
  "version_minor": 0
61
  },
@@ -96,30 +96,40 @@
96
  "output_type": "stream",
97
  "text": [
98
  "\n",
99
- "table: 2-13081928-2\n",
100
- "columns: Country,Chart,Period,Peak position,Sales\n",
101
- "Q: Name the period for Chart of g-music j-pop/k-pop chart\n",
102
- "A: SELECT Period FROM 2-13081928-2 WHERE Chart = 'g-music j-pop/k-pop chart'\n",
103
- "\n",
104
- "table: 2-13612447-1\n",
105
- "columns: Fraction,Ellipsis,Vinculum,Dots,Parentheses\n",
106
- "Q: What is the dot value when the ellipsis is 0.012345679…?\n",
107
- "A: SELECT Dots FROM 2-13612447-1 WHERE Ellipsis = '0.012345679…'\n",
108
- "\n",
109
- "table: 1-168274-1\n",
110
- "columns: Company,ICB Sector,Ticker symbol,Index weighting (%) at 17 January 2013,Market cap. at April 2013 (€)\n",
111
- "Q: Name the total number of index weighting % at 17 january 2013 for bouygues\n",
112
- "A: SELECT COUNT Index weighting (%) at 17 January 2013 FROM 1-168274-1 WHERE Company = 'Bouygues'\n",
113
- "\n",
114
- "table: 2-15826191-2\n",
115
- "columns: Rank,Nation,Gold,Silver,Bronze,Total\n",
116
- "Q: What is the lowest gold when there are 0 bronze and the total is less than 2, and silver is less than 0?\n",
117
- "A: SELECT MIN Gold FROM 2-15826191-2 WHERE Bronze = 0 AND Total < 2 AND Silver < 0\n",
118
- "\n",
119
- "table: 2-16387912-1\n",
120
- "columns: Home team,Home team score,Away team,Away team score,Ground,Date,Time\n",
121
- "Q: What is Ground, when Away Team is Sydney?\n",
122
- "A: SELECT Ground FROM 2-16387912-1 WHERE Away team = 'sydney'\n"
 
 
 
 
 
 
 
 
 
 
123
  ]
124
  }
125
  ],
@@ -228,17 +238,17 @@
228
  "name": "stdout",
229
  "output_type": "stream",
230
  "text": [
231
- "89\n",
232
  " 0\n",
233
  "count 56355.000000\n",
234
- "mean 98.219519\n",
235
  "std 21.740325\n",
236
- "min 60.000000\n",
237
- "25% 84.500000\n",
238
- "50% 94.000000\n",
239
- "75% 106.000000\n",
240
- "max 458.000000\n",
241
- "35608\n"
242
  ]
243
  }
244
  ],
@@ -262,18 +272,18 @@
262
  },
263
  {
264
  "cell_type": "code",
265
- "execution_count": 7,
266
  "metadata": {},
267
  "outputs": [
268
  {
269
  "data": {
270
  "application/vnd.jupyter.widget-view+json": {
271
- "model_id": "d548eb2af20f435fa1af81e9045a2d0e",
272
  "version_major": 2,
273
  "version_minor": 0
274
  },
275
  "text/plain": [
276
- "Map: 0%| | 0/1000 [00:00<?, ? examples/s]"
277
  ]
278
  },
279
  "metadata": {},
@@ -282,9 +292,8 @@
282
  ],
283
  "source": [
284
  "import random, datasets\n",
285
- "d = {'prompt': random.sample(data_red, 1000)}\n",
286
- "\n",
287
- "tokenizer.pad_token_id = tokenizer.eos_token\n",
288
  "\n",
289
  "data = datasets.Dataset.from_dict(d)\n",
290
  "data = data.map(lambda x:\n",
@@ -300,7 +309,7 @@
300
  },
301
  {
302
  "cell_type": "code",
303
- "execution_count": 8,
304
  "metadata": {},
305
  "outputs": [],
306
  "source": [
@@ -344,7 +353,7 @@
344
  },
345
  {
346
  "cell_type": "code",
347
- "execution_count": 9,
348
  "metadata": {},
349
  "outputs": [
350
  {
@@ -353,8 +362,8 @@
353
  "\n",
354
  " <div>\n",
355
  " \n",
356
- " <progress value='7' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
357
- " [7/7 05:33, Epoch 0/1]\n",
358
  " </div>\n",
359
  " <table border=\"1\" class=\"dataframe\">\n",
360
  " <thead>\n",
@@ -366,104 +375,1224 @@
366
  " <tbody>\n",
367
  " <tr>\n",
368
  " <td>1</td>\n",
369
- " <td>2.710700</td>\n",
370
  " </tr>\n",
371
  " <tr>\n",
372
  " <td>2</td>\n",
373
- " <td>2.680400</td>\n",
374
  " </tr>\n",
375
  " <tr>\n",
376
  " <td>3</td>\n",
377
- " <td>2.684500</td>\n",
378
  " </tr>\n",
379
  " <tr>\n",
380
  " <td>4</td>\n",
381
- " <td>2.625600</td>\n",
382
  " </tr>\n",
383
  " <tr>\n",
384
  " <td>5</td>\n",
385
- " <td>2.609600</td>\n",
386
  " </tr>\n",
387
  " <tr>\n",
388
  " <td>6</td>\n",
389
- " <td>2.619100</td>\n",
390
  " </tr>\n",
391
  " <tr>\n",
392
  " <td>7</td>\n",
393
- " <td>2.603800</td>\n",
394
  " </tr>\n",
395
- " </tbody>\n",
396
- "</table><p>"
397
- ],
398
- "text/plain": [
399
- "<IPython.core.display.HTML object>"
400
- ]
401
- },
402
- "metadata": {},
403
- "output_type": "display_data"
404
- }
405
- ],
406
- "source": [
407
- "trainer = transformers.Trainer(\n",
408
- " model = model,\n",
409
- " train_dataset = data,\n",
410
- " args = targs,\n",
411
- " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
412
- ")\n",
413
- "trainer.train(resume_from_checkpoint=False)\n",
414
- "model.save_pretrained('sqllama-out2')"
415
- ]
416
- },
417
- {
418
- "cell_type": "code",
419
- "execution_count": 10,
420
- "metadata": {},
421
- "outputs": [
422
- {
423
- "name": "stderr",
424
- "output_type": "stream",
425
- "text": [
426
- "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1220: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
427
- " \"You have modified the pretrained model configuration to control generation. This is a\"\n",
428
- "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
429
- " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
430
- ]
431
- },
432
- {
433
- "name": "stdout",
434
- "output_type": "stream",
435
- "text": [
436
- "from model\n",
437
- " ⁇ table: 1-25800134-1\n",
438
- "columns: Series #,Season #,Title,Director,Writer(s),Airdate\n",
439
- "Q: Who wrote the episode with series number 56?\n",
440
- "A: 56-101, \"The Cage\", Gene Roddenberry\n",
441
- "Q: Who wrote the episode with series number 56? (2)\n",
442
- "A: 56-101,\n",
443
- "expected answer SELECT Writer(s) FROM 1-25800134-1 WHERE Series # = 56\n"
444
- ]
445
- }
446
- ],
447
- "source": [
448
- "def get_query(q):\n",
449
- " \n",
450
- " toks = tokenizer(q , return_tensors='pt')\n",
451
- " ctoks = toks.input_ids.to('cuda')\n",
452
- " gen = model.generate(ctoks, max_length=100)\n",
453
- " return tokenizer.decode(gen[0])\n",
454
- "\n",
455
- "M = len(q_red)\n",
456
- "j = random.randint(0,M-1)\n",
457
- "qs = q_red[j]\n",
458
- "a = a_red[j]\n",
459
- "\n",
460
- "ma = get_query(qs)\n",
461
- "\n",
462
- "#print(qs)\n",
463
- "print('from model')\n",
464
- "print(ma)\n",
465
- "print\n",
466
- "print('expected answer',a)\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  ]
468
  }
469
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [
8
  {
 
11
  "True"
12
  ]
13
  },
14
+ "execution_count": 1,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
 
24
  },
25
  {
26
  "cell_type": "code",
27
+ "execution_count": 2,
28
  "metadata": {},
29
  "outputs": [
30
  {
 
55
  {
56
  "data": {
57
  "application/vnd.jupyter.widget-view+json": {
58
+ "model_id": "0b970b9989854c33aa4d19fa9457d7a2",
59
  "version_major": 2,
60
  "version_minor": 0
61
  },
 
96
  "output_type": "stream",
97
  "text": [
98
  "\n",
99
+ "table: 2-1137692-1\n",
100
+ "columns: Entrant,Constructor,Chassis,Engine †,Tyre,Driver,Rounds\n",
101
+ "Q: What were the rounds on the Engine † of the Ferrari 048?\n",
102
+ "A: SELECT Rounds FROM 2-1137692-1 WHERE Engine = 'ferrari 048'\n",
103
+ "END\n",
104
+ "\n",
105
+ "\n",
106
+ "table: 1-21530474-1\n",
107
+ "columns: Chassis code,Model no.,Production years,Drivetrain,Transmission,Engine type,Engine code,Region(s)\n",
108
+ "Q: Name the drivetrain for 1ur-fse for usf41\n",
109
+ "A: SELECT Drivetrain FROM 1-21530474-1 WHERE Engine code = '1UR-FSE' AND Chassis code = 'USF41'\n",
110
+ "END\n",
111
+ "\n",
112
+ "\n",
113
+ "table: 2-14155087-1\n",
114
+ "columns: Callsign,Area served,Frequency,Band,On-air ID,Purpose\n",
115
+ "Q: What is the Callsign with an Area of tamworth and frequency of 0 88.9?\n",
116
+ "A: SELECT Callsign FROM 2-14155087-1 WHERE Area served = 'tamworth' AND Frequency = '0 88.9'\n",
117
+ "END\n",
118
+ "\n",
119
+ "\n",
120
+ "table: 2-17580726-2\n",
121
+ "columns: Date,Opponent,Venue,Score,Attendance,Scorers\n",
122
+ "Q: What is the number of people in attendance when Tonbridge Angels is the opponent?\n",
123
+ "A: SELECT Attendance FROM 2-17580726-2 WHERE Opponent = 'tonbridge angels'\n",
124
+ "END\n",
125
+ "\n",
126
+ "\n",
127
+ "table: 1-27986200-3\n",
128
+ "columns: Proceed to Quarter-final,Match points,Aggregate score,Points margin,Eliminated from competition\n",
129
+ "Q: What were the match points when Bordeaux-Bègles was eliminated from competition? \n",
130
+ "A: SELECT Match points FROM 1-27986200-3 WHERE Eliminated from competition = 'Bordeaux-Bègles'\n",
131
+ "END\n",
132
+ "\n"
133
  ]
134
  }
135
  ],
 
238
  "name": "stdout",
239
  "output_type": "stream",
240
  "text": [
241
+ "92\n",
242
  " 0\n",
243
  "count 56355.000000\n",
244
+ "mean 101.219519\n",
245
  "std 21.740325\n",
246
+ "min 63.000000\n",
247
+ "25% 87.500000\n",
248
+ "50% 97.000000\n",
249
+ "75% 109.000000\n",
250
+ "max 461.000000\n",
251
+ "32084\n"
252
  ]
253
  }
254
  ],
 
272
  },
273
  {
274
  "cell_type": "code",
275
+ "execution_count": 6,
276
  "metadata": {},
277
  "outputs": [
278
  {
279
  "data": {
280
  "application/vnd.jupyter.widget-view+json": {
281
+ "model_id": "01debeeb68bb40a7b83031e88c5ace1e",
282
  "version_major": 2,
283
  "version_minor": 0
284
  },
285
  "text/plain": [
286
+ "Map: 0%| | 0/32084 [00:00<?, ? examples/s]"
287
  ]
288
  },
289
  "metadata": {},
 
292
  ],
293
  "source": [
294
  "import random, datasets\n",
295
+ "#d = {'prompt': random.sample(data_red, 1000)}\n",
296
+ "d = {'prompt': data_red}\n",
 
297
  "\n",
298
  "data = datasets.Dataset.from_dict(d)\n",
299
  "data = data.map(lambda x:\n",
 
309
  },
310
  {
311
  "cell_type": "code",
312
+ "execution_count": 7,
313
  "metadata": {},
314
  "outputs": [],
315
  "source": [
 
353
  },
354
  {
355
  "cell_type": "code",
356
+ "execution_count": 8,
357
  "metadata": {},
358
  "outputs": [
359
  {
 
362
  "\n",
363
  " <div>\n",
364
  " \n",
365
+ " <progress value='250' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
366
+ " [250/250 3:49:26, Epoch 0/1]\n",
367
  " </div>\n",
368
  " <table border=\"1\" class=\"dataframe\">\n",
369
  " <thead>\n",
 
375
  " <tbody>\n",
376
  " <tr>\n",
377
  " <td>1</td>\n",
378
+ " <td>2.748800</td>\n",
379
  " </tr>\n",
380
  " <tr>\n",
381
  " <td>2</td>\n",
382
+ " <td>2.699100</td>\n",
383
  " </tr>\n",
384
  " <tr>\n",
385
  " <td>3</td>\n",
386
+ " <td>2.670200</td>\n",
387
  " </tr>\n",
388
  " <tr>\n",
389
  " <td>4</td>\n",
390
+ " <td>2.600500</td>\n",
391
  " </tr>\n",
392
  " <tr>\n",
393
  " <td>5</td>\n",
394
+ " <td>2.560100</td>\n",
395
  " </tr>\n",
396
  " <tr>\n",
397
  " <td>6</td>\n",
398
+ " <td>2.556800</td>\n",
399
  " </tr>\n",
400
  " <tr>\n",
401
  " <td>7</td>\n",
402
+ " <td>2.498100</td>\n",
403
  " </tr>\n",
404
+ " <tr>\n",
405
+ " <td>8</td>\n",
406
+ " <td>2.515400</td>\n",
407
+ " </tr>\n",
408
+ " <tr>\n",
409
+ " <td>9</td>\n",
410
+ " <td>2.436100</td>\n",
411
+ " </tr>\n",
412
+ " <tr>\n",
413
+ " <td>10</td>\n",
414
+ " <td>2.411700</td>\n",
415
+ " </tr>\n",
416
+ " <tr>\n",
417
+ " <td>11</td>\n",
418
+ " <td>2.346400</td>\n",
419
+ " </tr>\n",
420
+ " <tr>\n",
421
+ " <td>12</td>\n",
422
+ " <td>2.276300</td>\n",
423
+ " </tr>\n",
424
+ " <tr>\n",
425
+ " <td>13</td>\n",
426
+ " <td>2.238000</td>\n",
427
+ " </tr>\n",
428
+ " <tr>\n",
429
+ " <td>14</td>\n",
430
+ " <td>2.189100</td>\n",
431
+ " </tr>\n",
432
+ " <tr>\n",
433
+ " <td>15</td>\n",
434
+ " <td>2.109200</td>\n",
435
+ " </tr>\n",
436
+ " <tr>\n",
437
+ " <td>16</td>\n",
438
+ " <td>2.058000</td>\n",
439
+ " </tr>\n",
440
+ " <tr>\n",
441
+ " <td>17</td>\n",
442
+ " <td>1.983900</td>\n",
443
+ " </tr>\n",
444
+ " <tr>\n",
445
+ " <td>18</td>\n",
446
+ " <td>1.928600</td>\n",
447
+ " </tr>\n",
448
+ " <tr>\n",
449
+ " <td>19</td>\n",
450
+ " <td>1.824100</td>\n",
451
+ " </tr>\n",
452
+ " <tr>\n",
453
+ " <td>20</td>\n",
454
+ " <td>1.794700</td>\n",
455
+ " </tr>\n",
456
+ " <tr>\n",
457
+ " <td>21</td>\n",
458
+ " <td>1.681200</td>\n",
459
+ " </tr>\n",
460
+ " <tr>\n",
461
+ " <td>22</td>\n",
462
+ " <td>1.598900</td>\n",
463
+ " </tr>\n",
464
+ " <tr>\n",
465
+ " <td>23</td>\n",
466
+ " <td>1.562000</td>\n",
467
+ " </tr>\n",
468
+ " <tr>\n",
469
+ " <td>24</td>\n",
470
+ " <td>1.527200</td>\n",
471
+ " </tr>\n",
472
+ " <tr>\n",
473
+ " <td>25</td>\n",
474
+ " <td>1.518700</td>\n",
475
+ " </tr>\n",
476
+ " <tr>\n",
477
+ " <td>26</td>\n",
478
+ " <td>1.493100</td>\n",
479
+ " </tr>\n",
480
+ " <tr>\n",
481
+ " <td>27</td>\n",
482
+ " <td>1.500500</td>\n",
483
+ " </tr>\n",
484
+ " <tr>\n",
485
+ " <td>28</td>\n",
486
+ " <td>1.464000</td>\n",
487
+ " </tr>\n",
488
+ " <tr>\n",
489
+ " <td>29</td>\n",
490
+ " <td>1.386900</td>\n",
491
+ " </tr>\n",
492
+ " <tr>\n",
493
+ " <td>30</td>\n",
494
+ " <td>1.373400</td>\n",
495
+ " </tr>\n",
496
+ " <tr>\n",
497
+ " <td>31</td>\n",
498
+ " <td>1.362200</td>\n",
499
+ " </tr>\n",
500
+ " <tr>\n",
501
+ " <td>32</td>\n",
502
+ " <td>1.360800</td>\n",
503
+ " </tr>\n",
504
+ " <tr>\n",
505
+ " <td>33</td>\n",
506
+ " <td>1.321000</td>\n",
507
+ " </tr>\n",
508
+ " <tr>\n",
509
+ " <td>34</td>\n",
510
+ " <td>1.310500</td>\n",
511
+ " </tr>\n",
512
+ " <tr>\n",
513
+ " <td>35</td>\n",
514
+ " <td>1.302600</td>\n",
515
+ " </tr>\n",
516
+ " <tr>\n",
517
+ " <td>36</td>\n",
518
+ " <td>1.256100</td>\n",
519
+ " </tr>\n",
520
+ " <tr>\n",
521
+ " <td>37</td>\n",
522
+ " <td>1.252500</td>\n",
523
+ " </tr>\n",
524
+ " <tr>\n",
525
+ " <td>38</td>\n",
526
+ " <td>1.202300</td>\n",
527
+ " </tr>\n",
528
+ " <tr>\n",
529
+ " <td>39</td>\n",
530
+ " <td>1.249100</td>\n",
531
+ " </tr>\n",
532
+ " <tr>\n",
533
+ " <td>40</td>\n",
534
+ " <td>1.188600</td>\n",
535
+ " </tr>\n",
536
+ " <tr>\n",
537
+ " <td>41</td>\n",
538
+ " <td>1.203200</td>\n",
539
+ " </tr>\n",
540
+ " <tr>\n",
541
+ " <td>42</td>\n",
542
+ " <td>1.150000</td>\n",
543
+ " </tr>\n",
544
+ " <tr>\n",
545
+ " <td>43</td>\n",
546
+ " <td>1.182000</td>\n",
547
+ " </tr>\n",
548
+ " <tr>\n",
549
+ " <td>44</td>\n",
550
+ " <td>1.192300</td>\n",
551
+ " </tr>\n",
552
+ " <tr>\n",
553
+ " <td>45</td>\n",
554
+ " <td>1.133100</td>\n",
555
+ " </tr>\n",
556
+ " <tr>\n",
557
+ " <td>46</td>\n",
558
+ " <td>1.119600</td>\n",
559
+ " </tr>\n",
560
+ " <tr>\n",
561
+ " <td>47</td>\n",
562
+ " <td>1.097000</td>\n",
563
+ " </tr>\n",
564
+ " <tr>\n",
565
+ " <td>48</td>\n",
566
+ " <td>1.142100</td>\n",
567
+ " </tr>\n",
568
+ " <tr>\n",
569
+ " <td>49</td>\n",
570
+ " <td>1.117200</td>\n",
571
+ " </tr>\n",
572
+ " <tr>\n",
573
+ " <td>50</td>\n",
574
+ " <td>1.129200</td>\n",
575
+ " </tr>\n",
576
+ " <tr>\n",
577
+ " <td>51</td>\n",
578
+ " <td>1.087300</td>\n",
579
+ " </tr>\n",
580
+ " <tr>\n",
581
+ " <td>52</td>\n",
582
+ " <td>1.098700</td>\n",
583
+ " </tr>\n",
584
+ " <tr>\n",
585
+ " <td>53</td>\n",
586
+ " <td>1.135400</td>\n",
587
+ " </tr>\n",
588
+ " <tr>\n",
589
+ " <td>54</td>\n",
590
+ " <td>1.071700</td>\n",
591
+ " </tr>\n",
592
+ " <tr>\n",
593
+ " <td>55</td>\n",
594
+ " <td>1.087300</td>\n",
595
+ " </tr>\n",
596
+ " <tr>\n",
597
+ " <td>56</td>\n",
598
+ " <td>1.051400</td>\n",
599
+ " </tr>\n",
600
+ " <tr>\n",
601
+ " <td>57</td>\n",
602
+ " <td>1.068300</td>\n",
603
+ " </tr>\n",
604
+ " <tr>\n",
605
+ " <td>58</td>\n",
606
+ " <td>1.092500</td>\n",
607
+ " </tr>\n",
608
+ " <tr>\n",
609
+ " <td>59</td>\n",
610
+ " <td>1.068600</td>\n",
611
+ " </tr>\n",
612
+ " <tr>\n",
613
+ " <td>60</td>\n",
614
+ " <td>1.072800</td>\n",
615
+ " </tr>\n",
616
+ " <tr>\n",
617
+ " <td>61</td>\n",
618
+ " <td>1.074000</td>\n",
619
+ " </tr>\n",
620
+ " <tr>\n",
621
+ " <td>62</td>\n",
622
+ " <td>1.060400</td>\n",
623
+ " </tr>\n",
624
+ " <tr>\n",
625
+ " <td>63</td>\n",
626
+ " <td>1.065800</td>\n",
627
+ " </tr>\n",
628
+ " <tr>\n",
629
+ " <td>64</td>\n",
630
+ " <td>1.075900</td>\n",
631
+ " </tr>\n",
632
+ " <tr>\n",
633
+ " <td>65</td>\n",
634
+ " <td>1.059500</td>\n",
635
+ " </tr>\n",
636
+ " <tr>\n",
637
+ " <td>66</td>\n",
638
+ " <td>1.039600</td>\n",
639
+ " </tr>\n",
640
+ " <tr>\n",
641
+ " <td>67</td>\n",
642
+ " <td>1.051400</td>\n",
643
+ " </tr>\n",
644
+ " <tr>\n",
645
+ " <td>68</td>\n",
646
+ " <td>1.049500</td>\n",
647
+ " </tr>\n",
648
+ " <tr>\n",
649
+ " <td>69</td>\n",
650
+ " <td>1.023800</td>\n",
651
+ " </tr>\n",
652
+ " <tr>\n",
653
+ " <td>70</td>\n",
654
+ " <td>1.071900</td>\n",
655
+ " </tr>\n",
656
+ " <tr>\n",
657
+ " <td>71</td>\n",
658
+ " <td>1.051000</td>\n",
659
+ " </tr>\n",
660
+ " <tr>\n",
661
+ " <td>72</td>\n",
662
+ " <td>1.034700</td>\n",
663
+ " </tr>\n",
664
+ " <tr>\n",
665
+ " <td>73</td>\n",
666
+ " <td>1.041600</td>\n",
667
+ " </tr>\n",
668
+ " <tr>\n",
669
+ " <td>74</td>\n",
670
+ " <td>1.030900</td>\n",
671
+ " </tr>\n",
672
+ " <tr>\n",
673
+ " <td>75</td>\n",
674
+ " <td>1.010800</td>\n",
675
+ " </tr>\n",
676
+ " <tr>\n",
677
+ " <td>76</td>\n",
678
+ " <td>1.019800</td>\n",
679
+ " </tr>\n",
680
+ " <tr>\n",
681
+ " <td>77</td>\n",
682
+ " <td>1.005000</td>\n",
683
+ " </tr>\n",
684
+ " <tr>\n",
685
+ " <td>78</td>\n",
686
+ " <td>1.043800</td>\n",
687
+ " </tr>\n",
688
+ " <tr>\n",
689
+ " <td>79</td>\n",
690
+ " <td>1.009200</td>\n",
691
+ " </tr>\n",
692
+ " <tr>\n",
693
+ " <td>80</td>\n",
694
+ " <td>1.017100</td>\n",
695
+ " </tr>\n",
696
+ " <tr>\n",
697
+ " <td>81</td>\n",
698
+ " <td>1.044600</td>\n",
699
+ " </tr>\n",
700
+ " <tr>\n",
701
+ " <td>82</td>\n",
702
+ " <td>1.022600</td>\n",
703
+ " </tr>\n",
704
+ " <tr>\n",
705
+ " <td>83</td>\n",
706
+ " <td>1.011400</td>\n",
707
+ " </tr>\n",
708
+ " <tr>\n",
709
+ " <td>84</td>\n",
710
+ " <td>0.996600</td>\n",
711
+ " </tr>\n",
712
+ " <tr>\n",
713
+ " <td>85</td>\n",
714
+ " <td>1.029900</td>\n",
715
+ " </tr>\n",
716
+ " <tr>\n",
717
+ " <td>86</td>\n",
718
+ " <td>0.988200</td>\n",
719
+ " </tr>\n",
720
+ " <tr>\n",
721
+ " <td>87</td>\n",
722
+ " <td>1.005600</td>\n",
723
+ " </tr>\n",
724
+ " <tr>\n",
725
+ " <td>88</td>\n",
726
+ " <td>0.986600</td>\n",
727
+ " </tr>\n",
728
+ " <tr>\n",
729
+ " <td>89</td>\n",
730
+ " <td>1.025300</td>\n",
731
+ " </tr>\n",
732
+ " <tr>\n",
733
+ " <td>90</td>\n",
734
+ " <td>1.012500</td>\n",
735
+ " </tr>\n",
736
+ " <tr>\n",
737
+ " <td>91</td>\n",
738
+ " <td>0.988100</td>\n",
739
+ " </tr>\n",
740
+ " <tr>\n",
741
+ " <td>92</td>\n",
742
+ " <td>1.001800</td>\n",
743
+ " </tr>\n",
744
+ " <tr>\n",
745
+ " <td>93</td>\n",
746
+ " <td>0.987100</td>\n",
747
+ " </tr>\n",
748
+ " <tr>\n",
749
+ " <td>94</td>\n",
750
+ " <td>1.017600</td>\n",
751
+ " </tr>\n",
752
+ " <tr>\n",
753
+ " <td>95</td>\n",
754
+ " <td>0.998500</td>\n",
755
+ " </tr>\n",
756
+ " <tr>\n",
757
+ " <td>96</td>\n",
758
+ " <td>0.966600</td>\n",
759
+ " </tr>\n",
760
+ " <tr>\n",
761
+ " <td>97</td>\n",
762
+ " <td>0.983700</td>\n",
763
+ " </tr>\n",
764
+ " <tr>\n",
765
+ " <td>98</td>\n",
766
+ " <td>0.961800</td>\n",
767
+ " </tr>\n",
768
+ " <tr>\n",
769
+ " <td>99</td>\n",
770
+ " <td>0.969000</td>\n",
771
+ " </tr>\n",
772
+ " <tr>\n",
773
+ " <td>100</td>\n",
774
+ " <td>0.989200</td>\n",
775
+ " </tr>\n",
776
+ " <tr>\n",
777
+ " <td>101</td>\n",
778
+ " <td>0.956400</td>\n",
779
+ " </tr>\n",
780
+ " <tr>\n",
781
+ " <td>102</td>\n",
782
+ " <td>0.976000</td>\n",
783
+ " </tr>\n",
784
+ " <tr>\n",
785
+ " <td>103</td>\n",
786
+ " <td>1.000100</td>\n",
787
+ " </tr>\n",
788
+ " <tr>\n",
789
+ " <td>104</td>\n",
790
+ " <td>1.001500</td>\n",
791
+ " </tr>\n",
792
+ " <tr>\n",
793
+ " <td>105</td>\n",
794
+ " <td>0.995900</td>\n",
795
+ " </tr>\n",
796
+ " <tr>\n",
797
+ " <td>106</td>\n",
798
+ " <td>0.989700</td>\n",
799
+ " </tr>\n",
800
+ " <tr>\n",
801
+ " <td>107</td>\n",
802
+ " <td>0.965700</td>\n",
803
+ " </tr>\n",
804
+ " <tr>\n",
805
+ " <td>108</td>\n",
806
+ " <td>0.968400</td>\n",
807
+ " </tr>\n",
808
+ " <tr>\n",
809
+ " <td>109</td>\n",
810
+ " <td>1.019600</td>\n",
811
+ " </tr>\n",
812
+ " <tr>\n",
813
+ " <td>110</td>\n",
814
+ " <td>1.000100</td>\n",
815
+ " </tr>\n",
816
+ " <tr>\n",
817
+ " <td>111</td>\n",
818
+ " <td>0.978500</td>\n",
819
+ " </tr>\n",
820
+ " <tr>\n",
821
+ " <td>112</td>\n",
822
+ " <td>0.978900</td>\n",
823
+ " </tr>\n",
824
+ " <tr>\n",
825
+ " <td>113</td>\n",
826
+ " <td>0.952600</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <td>114</td>\n",
830
+ " <td>0.975400</td>\n",
831
+ " </tr>\n",
832
+ " <tr>\n",
833
+ " <td>115</td>\n",
834
+ " <td>0.989400</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>116</td>\n",
838
+ " <td>0.968500</td>\n",
839
+ " </tr>\n",
840
+ " <tr>\n",
841
+ " <td>117</td>\n",
842
+ " <td>0.960100</td>\n",
843
+ " </tr>\n",
844
+ " <tr>\n",
845
+ " <td>118</td>\n",
846
+ " <td>0.979100</td>\n",
847
+ " </tr>\n",
848
+ " <tr>\n",
849
+ " <td>119</td>\n",
850
+ " <td>0.955100</td>\n",
851
+ " </tr>\n",
852
+ " <tr>\n",
853
+ " <td>120</td>\n",
854
+ " <td>0.934800</td>\n",
855
+ " </tr>\n",
856
+ " <tr>\n",
857
+ " <td>121</td>\n",
858
+ " <td>0.943600</td>\n",
859
+ " </tr>\n",
860
+ " <tr>\n",
861
+ " <td>122</td>\n",
862
+ " <td>0.976700</td>\n",
863
+ " </tr>\n",
864
+ " <tr>\n",
865
+ " <td>123</td>\n",
866
+ " <td>0.998700</td>\n",
867
+ " </tr>\n",
868
+ " <tr>\n",
869
+ " <td>124</td>\n",
870
+ " <td>0.930500</td>\n",
871
+ " </tr>\n",
872
+ " <tr>\n",
873
+ " <td>125</td>\n",
874
+ " <td>0.953500</td>\n",
875
+ " </tr>\n",
876
+ " <tr>\n",
877
+ " <td>126</td>\n",
878
+ " <td>0.978000</td>\n",
879
+ " </tr>\n",
880
+ " <tr>\n",
881
+ " <td>127</td>\n",
882
+ " <td>0.967300</td>\n",
883
+ " </tr>\n",
884
+ " <tr>\n",
885
+ " <td>128</td>\n",
886
+ " <td>0.929400</td>\n",
887
+ " </tr>\n",
888
+ " <tr>\n",
889
+ " <td>129</td>\n",
890
+ " <td>0.963100</td>\n",
891
+ " </tr>\n",
892
+ " <tr>\n",
893
+ " <td>130</td>\n",
894
+ " <td>0.961500</td>\n",
895
+ " </tr>\n",
896
+ " <tr>\n",
897
+ " <td>131</td>\n",
898
+ " <td>0.978500</td>\n",
899
+ " </tr>\n",
900
+ " <tr>\n",
901
+ " <td>132</td>\n",
902
+ " <td>0.937200</td>\n",
903
+ " </tr>\n",
904
+ " <tr>\n",
905
+ " <td>133</td>\n",
906
+ " <td>0.953400</td>\n",
907
+ " </tr>\n",
908
+ " <tr>\n",
909
+ " <td>134</td>\n",
910
+ " <td>0.962000</td>\n",
911
+ " </tr>\n",
912
+ " <tr>\n",
913
+ " <td>135</td>\n",
914
+ " <td>0.950700</td>\n",
915
+ " </tr>\n",
916
+ " <tr>\n",
917
+ " <td>136</td>\n",
918
+ " <td>0.925100</td>\n",
919
+ " </tr>\n",
920
+ " <tr>\n",
921
+ " <td>137</td>\n",
922
+ " <td>0.958800</td>\n",
923
+ " </tr>\n",
924
+ " <tr>\n",
925
+ " <td>138</td>\n",
926
+ " <td>0.926200</td>\n",
927
+ " </tr>\n",
928
+ " <tr>\n",
929
+ " <td>139</td>\n",
930
+ " <td>0.930600</td>\n",
931
+ " </tr>\n",
932
+ " <tr>\n",
933
+ " <td>140</td>\n",
934
+ " <td>0.968900</td>\n",
935
+ " </tr>\n",
936
+ " <tr>\n",
937
+ " <td>141</td>\n",
938
+ " <td>0.970400</td>\n",
939
+ " </tr>\n",
940
+ " <tr>\n",
941
+ " <td>142</td>\n",
942
+ " <td>0.927100</td>\n",
943
+ " </tr>\n",
944
+ " <tr>\n",
945
+ " <td>143</td>\n",
946
+ " <td>0.911800</td>\n",
947
+ " </tr>\n",
948
+ " <tr>\n",
949
+ " <td>144</td>\n",
950
+ " <td>0.953200</td>\n",
951
+ " </tr>\n",
952
+ " <tr>\n",
953
+ " <td>145</td>\n",
954
+ " <td>0.907100</td>\n",
955
+ " </tr>\n",
956
+ " <tr>\n",
957
+ " <td>146</td>\n",
958
+ " <td>0.935900</td>\n",
959
+ " </tr>\n",
960
+ " <tr>\n",
961
+ " <td>147</td>\n",
962
+ " <td>0.970600</td>\n",
963
+ " </tr>\n",
964
+ " <tr>\n",
965
+ " <td>148</td>\n",
966
+ " <td>0.920400</td>\n",
967
+ " </tr>\n",
968
+ " <tr>\n",
969
+ " <td>149</td>\n",
970
+ " <td>0.930200</td>\n",
971
+ " </tr>\n",
972
+ " <tr>\n",
973
+ " <td>150</td>\n",
974
+ " <td>0.926700</td>\n",
975
+ " </tr>\n",
976
+ " <tr>\n",
977
+ " <td>151</td>\n",
978
+ " <td>0.913400</td>\n",
979
+ " </tr>\n",
980
+ " <tr>\n",
981
+ " <td>152</td>\n",
982
+ " <td>0.926800</td>\n",
983
+ " </tr>\n",
984
+ " <tr>\n",
985
+ " <td>153</td>\n",
986
+ " <td>0.967200</td>\n",
987
+ " </tr>\n",
988
+ " <tr>\n",
989
+ " <td>154</td>\n",
990
+ " <td>0.939500</td>\n",
991
+ " </tr>\n",
992
+ " <tr>\n",
993
+ " <td>155</td>\n",
994
+ " <td>0.910600</td>\n",
995
+ " </tr>\n",
996
+ " <tr>\n",
997
+ " <td>156</td>\n",
998
+ " <td>0.926400</td>\n",
999
+ " </tr>\n",
1000
+ " <tr>\n",
1001
+ " <td>157</td>\n",
1002
+ " <td>0.935400</td>\n",
1003
+ " </tr>\n",
1004
+ " <tr>\n",
1005
+ " <td>158</td>\n",
1006
+ " <td>0.967700</td>\n",
1007
+ " </tr>\n",
1008
+ " <tr>\n",
1009
+ " <td>159</td>\n",
1010
+ " <td>0.899000</td>\n",
1011
+ " </tr>\n",
1012
+ " <tr>\n",
1013
+ " <td>160</td>\n",
1014
+ " <td>0.916600</td>\n",
1015
+ " </tr>\n",
1016
+ " <tr>\n",
1017
+ " <td>161</td>\n",
1018
+ " <td>0.961600</td>\n",
1019
+ " </tr>\n",
1020
+ " <tr>\n",
1021
+ " <td>162</td>\n",
1022
+ " <td>0.898200</td>\n",
1023
+ " </tr>\n",
1024
+ " <tr>\n",
1025
+ " <td>163</td>\n",
1026
+ " <td>0.944600</td>\n",
1027
+ " </tr>\n",
1028
+ " <tr>\n",
1029
+ " <td>164</td>\n",
1030
+ " <td>0.935700</td>\n",
1031
+ " </tr>\n",
1032
+ " <tr>\n",
1033
+ " <td>165</td>\n",
1034
+ " <td>0.922500</td>\n",
1035
+ " </tr>\n",
1036
+ " <tr>\n",
1037
+ " <td>166</td>\n",
1038
+ " <td>0.897600</td>\n",
1039
+ " </tr>\n",
1040
+ " <tr>\n",
1041
+ " <td>167</td>\n",
1042
+ " <td>0.968600</td>\n",
1043
+ " </tr>\n",
1044
+ " <tr>\n",
1045
+ " <td>168</td>\n",
1046
+ " <td>0.927400</td>\n",
1047
+ " </tr>\n",
1048
+ " <tr>\n",
1049
+ " <td>169</td>\n",
1050
+ " <td>0.910900</td>\n",
1051
+ " </tr>\n",
1052
+ " <tr>\n",
1053
+ " <td>170</td>\n",
1054
+ " <td>0.904700</td>\n",
1055
+ " </tr>\n",
1056
+ " <tr>\n",
1057
+ " <td>171</td>\n",
1058
+ " <td>0.899800</td>\n",
1059
+ " </tr>\n",
1060
+ " <tr>\n",
1061
+ " <td>172</td>\n",
1062
+ " <td>0.896400</td>\n",
1063
+ " </tr>\n",
1064
+ " <tr>\n",
1065
+ " <td>173</td>\n",
1066
+ " <td>0.862100</td>\n",
1067
+ " </tr>\n",
1068
+ " <tr>\n",
1069
+ " <td>174</td>\n",
1070
+ " <td>0.909100</td>\n",
1071
+ " </tr>\n",
1072
+ " <tr>\n",
1073
+ " <td>175</td>\n",
1074
+ " <td>0.903200</td>\n",
1075
+ " </tr>\n",
1076
+ " <tr>\n",
1077
+ " <td>176</td>\n",
1078
+ " <td>0.958600</td>\n",
1079
+ " </tr>\n",
1080
+ " <tr>\n",
1081
+ " <td>177</td>\n",
1082
+ " <td>0.902500</td>\n",
1083
+ " </tr>\n",
1084
+ " <tr>\n",
1085
+ " <td>178</td>\n",
1086
+ " <td>0.894900</td>\n",
1087
+ " </tr>\n",
1088
+ " <tr>\n",
1089
+ " <td>179</td>\n",
1090
+ " <td>0.937900</td>\n",
1091
+ " </tr>\n",
1092
+ " <tr>\n",
1093
+ " <td>180</td>\n",
1094
+ " <td>0.900700</td>\n",
1095
+ " </tr>\n",
1096
+ " <tr>\n",
1097
+ " <td>181</td>\n",
1098
+ " <td>0.922300</td>\n",
1099
+ " </tr>\n",
1100
+ " <tr>\n",
1101
+ " <td>182</td>\n",
1102
+ " <td>0.939300</td>\n",
1103
+ " </tr>\n",
1104
+ " <tr>\n",
1105
+ " <td>183</td>\n",
1106
+ " <td>0.932600</td>\n",
1107
+ " </tr>\n",
1108
+ " <tr>\n",
1109
+ " <td>184</td>\n",
1110
+ " <td>0.913300</td>\n",
1111
+ " </tr>\n",
1112
+ " <tr>\n",
1113
+ " <td>185</td>\n",
1114
+ " <td>0.941700</td>\n",
1115
+ " </tr>\n",
1116
+ " <tr>\n",
1117
+ " <td>186</td>\n",
1118
+ " <td>0.886300</td>\n",
1119
+ " </tr>\n",
1120
+ " <tr>\n",
1121
+ " <td>187</td>\n",
1122
+ " <td>0.918000</td>\n",
1123
+ " </tr>\n",
1124
+ " <tr>\n",
1125
+ " <td>188</td>\n",
1126
+ " <td>0.884000</td>\n",
1127
+ " </tr>\n",
1128
+ " <tr>\n",
1129
+ " <td>189</td>\n",
1130
+ " <td>0.947400</td>\n",
1131
+ " </tr>\n",
1132
+ " <tr>\n",
1133
+ " <td>190</td>\n",
1134
+ " <td>0.894500</td>\n",
1135
+ " </tr>\n",
1136
+ " <tr>\n",
1137
+ " <td>191</td>\n",
1138
+ " <td>0.929300</td>\n",
1139
+ " </tr>\n",
1140
+ " <tr>\n",
1141
+ " <td>192</td>\n",
1142
+ " <td>0.877300</td>\n",
1143
+ " </tr>\n",
1144
+ " <tr>\n",
1145
+ " <td>193</td>\n",
1146
+ " <td>0.894300</td>\n",
1147
+ " </tr>\n",
1148
+ " <tr>\n",
1149
+ " <td>194</td>\n",
1150
+ " <td>0.867800</td>\n",
1151
+ " </tr>\n",
1152
+ " <tr>\n",
1153
+ " <td>195</td>\n",
1154
+ " <td>0.913500</td>\n",
1155
+ " </tr>\n",
1156
+ " <tr>\n",
1157
+ " <td>196</td>\n",
1158
+ " <td>0.908100</td>\n",
1159
+ " </tr>\n",
1160
+ " <tr>\n",
1161
+ " <td>197</td>\n",
1162
+ " <td>0.931200</td>\n",
1163
+ " </tr>\n",
1164
+ " <tr>\n",
1165
+ " <td>198</td>\n",
1166
+ " <td>0.911000</td>\n",
1167
+ " </tr>\n",
1168
+ " <tr>\n",
1169
+ " <td>199</td>\n",
1170
+ " <td>0.941800</td>\n",
1171
+ " </tr>\n",
1172
+ " <tr>\n",
1173
+ " <td>200</td>\n",
1174
+ " <td>0.913000</td>\n",
1175
+ " </tr>\n",
1176
+ " <tr>\n",
1177
+ " <td>201</td>\n",
1178
+ " <td>0.921800</td>\n",
1179
+ " </tr>\n",
1180
+ " <tr>\n",
1181
+ " <td>202</td>\n",
1182
+ " <td>0.921700</td>\n",
1183
+ " </tr>\n",
1184
+ " <tr>\n",
1185
+ " <td>203</td>\n",
1186
+ " <td>0.914500</td>\n",
1187
+ " </tr>\n",
1188
+ " <tr>\n",
1189
+ " <td>204</td>\n",
1190
+ " <td>0.910500</td>\n",
1191
+ " </tr>\n",
1192
+ " <tr>\n",
1193
+ " <td>205</td>\n",
1194
+ " <td>0.906600</td>\n",
1195
+ " </tr>\n",
1196
+ " <tr>\n",
1197
+ " <td>206</td>\n",
1198
+ " <td>0.915100</td>\n",
1199
+ " </tr>\n",
1200
+ " <tr>\n",
1201
+ " <td>207</td>\n",
1202
+ " <td>0.881600</td>\n",
1203
+ " </tr>\n",
1204
+ " <tr>\n",
1205
+ " <td>208</td>\n",
1206
+ " <td>0.884700</td>\n",
1207
+ " </tr>\n",
1208
+ " <tr>\n",
1209
+ " <td>209</td>\n",
1210
+ " <td>0.902900</td>\n",
1211
+ " </tr>\n",
1212
+ " <tr>\n",
1213
+ " <td>210</td>\n",
1214
+ " <td>0.882600</td>\n",
1215
+ " </tr>\n",
1216
+ " <tr>\n",
1217
+ " <td>211</td>\n",
1218
+ " <td>0.891000</td>\n",
1219
+ " </tr>\n",
1220
+ " <tr>\n",
1221
+ " <td>212</td>\n",
1222
+ " <td>0.914400</td>\n",
1223
+ " </tr>\n",
1224
+ " <tr>\n",
1225
+ " <td>213</td>\n",
1226
+ " <td>0.930400</td>\n",
1227
+ " </tr>\n",
1228
+ " <tr>\n",
1229
+ " <td>214</td>\n",
1230
+ " <td>0.891100</td>\n",
1231
+ " </tr>\n",
1232
+ " <tr>\n",
1233
+ " <td>215</td>\n",
1234
+ " <td>0.859300</td>\n",
1235
+ " </tr>\n",
1236
+ " <tr>\n",
1237
+ " <td>216</td>\n",
1238
+ " <td>0.891800</td>\n",
1239
+ " </tr>\n",
1240
+ " <tr>\n",
1241
+ " <td>217</td>\n",
1242
+ " <td>0.873000</td>\n",
1243
+ " </tr>\n",
1244
+ " <tr>\n",
1245
+ " <td>218</td>\n",
1246
+ " <td>0.925900</td>\n",
1247
+ " </tr>\n",
1248
+ " <tr>\n",
1249
+ " <td>219</td>\n",
1250
+ " <td>0.905700</td>\n",
1251
+ " </tr>\n",
1252
+ " <tr>\n",
1253
+ " <td>220</td>\n",
1254
+ " <td>0.921200</td>\n",
1255
+ " </tr>\n",
1256
+ " <tr>\n",
1257
+ " <td>221</td>\n",
1258
+ " <td>0.890200</td>\n",
1259
+ " </tr>\n",
1260
+ " <tr>\n",
1261
+ " <td>222</td>\n",
1262
+ " <td>0.915800</td>\n",
1263
+ " </tr>\n",
1264
+ " <tr>\n",
1265
+ " <td>223</td>\n",
1266
+ " <td>0.887300</td>\n",
1267
+ " </tr>\n",
1268
+ " <tr>\n",
1269
+ " <td>224</td>\n",
1270
+ " <td>0.898300</td>\n",
1271
+ " </tr>\n",
1272
+ " <tr>\n",
1273
+ " <td>225</td>\n",
1274
+ " <td>0.865600</td>\n",
1275
+ " </tr>\n",
1276
+ " <tr>\n",
1277
+ " <td>226</td>\n",
1278
+ " <td>0.873900</td>\n",
1279
+ " </tr>\n",
1280
+ " <tr>\n",
1281
+ " <td>227</td>\n",
1282
+ " <td>0.904800</td>\n",
1283
+ " </tr>\n",
1284
+ " <tr>\n",
1285
+ " <td>228</td>\n",
1286
+ " <td>0.917900</td>\n",
1287
+ " </tr>\n",
1288
+ " <tr>\n",
1289
+ " <td>229</td>\n",
1290
+ " <td>0.923400</td>\n",
1291
+ " </tr>\n",
1292
+ " <tr>\n",
1293
+ " <td>230</td>\n",
1294
+ " <td>0.939700</td>\n",
1295
+ " </tr>\n",
1296
+ " <tr>\n",
1297
+ " <td>231</td>\n",
1298
+ " <td>0.913400</td>\n",
1299
+ " </tr>\n",
1300
+ " <tr>\n",
1301
+ " <td>232</td>\n",
1302
+ " <td>0.873100</td>\n",
1303
+ " </tr>\n",
1304
+ " <tr>\n",
1305
+ " <td>233</td>\n",
1306
+ " <td>0.896700</td>\n",
1307
+ " </tr>\n",
1308
+ " <tr>\n",
1309
+ " <td>234</td>\n",
1310
+ " <td>0.892100</td>\n",
1311
+ " </tr>\n",
1312
+ " <tr>\n",
1313
+ " <td>235</td>\n",
1314
+ " <td>0.902100</td>\n",
1315
+ " </tr>\n",
1316
+ " <tr>\n",
1317
+ " <td>236</td>\n",
1318
+ " <td>0.927200</td>\n",
1319
+ " </tr>\n",
1320
+ " <tr>\n",
1321
+ " <td>237</td>\n",
1322
+ " <td>0.912900</td>\n",
1323
+ " </tr>\n",
1324
+ " <tr>\n",
1325
+ " <td>238</td>\n",
1326
+ " <td>0.872900</td>\n",
1327
+ " </tr>\n",
1328
+ " <tr>\n",
1329
+ " <td>239</td>\n",
1330
+ " <td>0.904700</td>\n",
1331
+ " </tr>\n",
1332
+ " <tr>\n",
1333
+ " <td>240</td>\n",
1334
+ " <td>0.879600</td>\n",
1335
+ " </tr>\n",
1336
+ " <tr>\n",
1337
+ " <td>241</td>\n",
1338
+ " <td>0.879800</td>\n",
1339
+ " </tr>\n",
1340
+ " <tr>\n",
1341
+ " <td>242</td>\n",
1342
+ " <td>0.908800</td>\n",
1343
+ " </tr>\n",
1344
+ " <tr>\n",
1345
+ " <td>243</td>\n",
1346
+ " <td>0.909800</td>\n",
1347
+ " </tr>\n",
1348
+ " <tr>\n",
1349
+ " <td>244</td>\n",
1350
+ " <td>0.838400</td>\n",
1351
+ " </tr>\n",
1352
+ " <tr>\n",
1353
+ " <td>245</td>\n",
1354
+ " <td>0.889200</td>\n",
1355
+ " </tr>\n",
1356
+ " <tr>\n",
1357
+ " <td>246</td>\n",
1358
+ " <td>0.912900</td>\n",
1359
+ " </tr>\n",
1360
+ " <tr>\n",
1361
+ " <td>247</td>\n",
1362
+ " <td>0.879700</td>\n",
1363
+ " </tr>\n",
1364
+ " <tr>\n",
1365
+ " <td>248</td>\n",
1366
+ " <td>0.910700</td>\n",
1367
+ " </tr>\n",
1368
+ " <tr>\n",
1369
+ " <td>249</td>\n",
1370
+ " <td>0.845400</td>\n",
1371
+ " </tr>\n",
1372
+ " <tr>\n",
1373
+ " <td>250</td>\n",
1374
+ " <td>0.882200</td>\n",
1375
+ " </tr>\n",
1376
+ " </tbody>\n",
1377
+ "</table><p>"
1378
+ ],
1379
+ "text/plain": [
1380
+ "<IPython.core.display.HTML object>"
1381
+ ]
1382
+ },
1383
+ "metadata": {},
1384
+ "output_type": "display_data"
1385
+ }
1386
+ ],
1387
+ "source": [
1388
+ "trainer = transformers.Trainer(\n",
1389
+ " model = model,\n",
1390
+ " train_dataset = data,\n",
1391
+ " args = targs,\n",
1392
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
1393
+ ")\n",
1394
+ "trainer.train(resume_from_checkpoint=False)\n",
1395
+ "model.save_pretrained('sqllama-out2')"
1396
+ ]
1397
+ },
1398
+ {
1399
+ "cell_type": "code",
1400
+ "execution_count": 11,
1401
+ "metadata": {},
1402
+ "outputs": [
1403
+ {
1404
+ "name": "stderr",
1405
+ "output_type": "stream",
1406
+ "text": [
1407
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
1408
+ " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
1409
+ ]
1410
+ },
1411
+ {
1412
+ "name": "stdout",
1413
+ "output_type": "stream",
1414
+ "text": [
1415
+ "from model\n",
1416
+ "<unk>table: 1-12028543-3\n",
1417
+ "columns: Season,Cup FinalDate,WinningTeam,Score,LosingTeam,Location,Cup Final Attendance\n",
1418
+ "Q: Who was the winning team in the 1989 season?\n",
1419
+ "A: SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n",
1420
+ "END\n",
1421
+ "END\n",
1422
+ "END\n",
1423
+ "END\n",
1424
+ "\n",
1425
+ "expected answer\n",
1426
+ "SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n",
1427
+ "END\n",
1428
+ "\n",
1429
+ "from model\n",
1430
+ "<unk>table: 2-18096431-5\n",
1431
+ "columns: Place,Player,Country,Score,To par\n",
1432
+ "Q: What is To par, when Country is \"United States\", and when Player is \"Mark Brooks\"?\n",
1433
+ "A: 18-1\n",
1434
+ "END\n",
1435
+ "\n",
1436
+ "\n",
1437
+ "expected answer\n",
1438
+ "SELECT To par FROM 2-18096431-5 WHERE Country = 'united states' AND Player = 'mark brooks'\n",
1439
+ "END\n",
1440
+ "\n",
1441
+ "from model\n",
1442
+ "<unk>table: 2-10701914-2\n",
1443
+ "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n",
1444
+ "Q: What home team played at the western oval?\n",
1445
+ "A: Western Bulldogs\n",
1446
+ "END\n",
1447
+ "END\n",
1448
+ "END\n",
1449
+ "END\n",
1450
+ "END\n",
1451
+ "END\n",
1452
+ "END\n",
1453
+ "END\n",
1454
+ "END\n",
1455
+ "END\n",
1456
+ "END\n",
1457
+ "END\n",
1458
+ "END\n",
1459
+ "END\n",
1460
+ "END\n",
1461
+ "END\n",
1462
+ "END\n",
1463
+ "\n",
1464
+ "\n",
1465
+ "expected answer\n",
1466
+ "SELECT Home team FROM 2-10701914-2 WHERE Venue = 'western oval'\n",
1467
+ "END\n",
1468
+ "\n",
1469
+ "from model\n",
1470
+ "<unk>table: 1-29598261-1\n",
1471
+ "columns: Name,Number,Position,Height,Weight,Year,Hometown,Last School/College\n",
1472
+ "Q: what is the school for chris mcnamara?\n",
1473
+ "A: SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'chris mcnamara'\n",
1474
+ "END\n",
1475
+ "END\n",
1476
+ "END\n",
1477
+ "END\n",
1478
+ "\n",
1479
+ "\n",
1480
+ "expected answer\n",
1481
+ "SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'Chris McNamara'\n",
1482
+ "END\n",
1483
+ "\n",
1484
+ "from model\n",
1485
+ "<unk>table: 1-27722408-11\n",
1486
+ "columns: Game,Date,Team,Score,High points,High rebounds,High assists,Location Attendance,Record\n",
1487
+ "Q: Who had the most assists and how many did they have on April 8?\n",
1488
+ "A: SELECT High assists FROM 1-27722408-11 WHERE Date = 'april 8'\n",
1489
+ "END\n",
1490
+ "\n",
1491
+ "\n",
1492
+ "expected answer\n",
1493
+ "SELECT High assists FROM 1-27722408-11 WHERE Date = 'April 8'\n",
1494
+ "END\n",
1495
+ "\n",
1496
+ "from model\n",
1497
+ "<unk>table: 1-21378339-5\n",
1498
+ "columns: Draw,Song,Artist,Panel Points,Televotes,Televote Points,Score,Placing\n",
1499
+ "Q: Name the number of artists for panel points being 5\n",
1500
+ "A: SELECT COUNT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n",
1501
+ "END\n",
1502
+ "END\n",
1503
+ "END\n",
1504
+ "END\n",
1505
+ "END\n",
1506
+ "\n",
1507
+ "expected answer\n",
1508
+ "SELECT COUNT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n",
1509
+ "END\n",
1510
+ "\n",
1511
+ "from model\n",
1512
+ "<unk>table: 2-11545282-17\n",
1513
+ "columns: Player,Nationality,Position,Years for Jazz,School/Club Team\n",
1514
+ "Q: What position does Michael Ruffin play?\n",
1515
+ "A: SELECT Position FROM 2-11545282-17 WHERE Player = 'michael ruffin'\n",
1516
+ "END\n",
1517
+ "END\n",
1518
+ "END\n",
1519
+ "END\n",
1520
+ "END\n",
1521
+ "END\n",
1522
+ "END\n",
1523
+ "END\n",
1524
+ "\n",
1525
+ "\n",
1526
+ "expected answer\n",
1527
+ "SELECT Position FROM 2-11545282-17 WHERE Player = 'michael ruffin'\n",
1528
+ "END\n",
1529
+ "\n",
1530
+ "from model\n",
1531
+ "<unk>table: 1-17801022-1\n",
1532
+ "columns: Year,Date,Driver,Manufacturer,Laps,Miles (km),Race Time,Average Speed (mph)\n",
1533
+ "Q: What manufacturer won the race on November 2?\n",
1534
+ "A: SELECT Manufacturer FROM 1-17801022-1 WHERE Date = 'november 2'\n",
1535
+ "END\n",
1536
+ "END\n",
1537
+ "END\n",
1538
+ "\n",
1539
+ "expected answer\n",
1540
+ "SELECT Manufacturer FROM 1-17801022-1 WHERE Date = 'November 2'\n",
1541
+ "END\n",
1542
+ "\n",
1543
+ "from model\n",
1544
+ "<unk>table: 2-10806592-14\n",
1545
+ "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n",
1546
+ "Q: What was the away score when the home team was Melbourne?\n",
1547
+ "A: SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n",
1548
+ "END\n",
1549
+ "END\n",
1550
+ "END\n",
1551
+ "\n",
1552
+ "\n",
1553
+ "expected answer\n",
1554
+ "SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n",
1555
+ "END\n",
1556
+ "\n",
1557
+ "from model\n",
1558
+ "<unk>table: 2-17978030-6\n",
1559
+ "columns: Date,Time,Score,Set 1,Set 2,Set 3,Total\n",
1560
+ "Q: What is the score when the set 3 is 26–28?\n",
1561
+ "A: SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n",
1562
+ "END\n",
1563
+ "END\n",
1564
+ "Q: What\n",
1565
+ "\n",
1566
+ "expected answer\n",
1567
+ "SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n",
1568
+ "END\n",
1569
+ "\n"
1570
+ ]
1571
+ }
1572
+ ],
1573
+ "source": [
1574
+ "def get_query(q):\n",
1575
+ " \n",
1576
+ " toks = tokenizer(q , return_tensors='pt')\n",
1577
+ " ctoks = toks.input_ids.to('cuda')\n",
1578
+ " gen = model.generate(ctoks, max_length=100)\n",
1579
+ " return tokenizer.decode(gen[0])\n",
1580
+ "\n",
1581
+ "M = len(q_red)\n",
1582
+ "\n",
1583
+ "for _ in range(10):\n",
1584
+ " j = random.randint(0,M-1)\n",
1585
+ " qs = q_red[j]\n",
1586
+ " a = a_red[j]\n",
1587
+ "\n",
1588
+ " ma = get_query(qs)\n",
1589
+ "\n",
1590
+ " #print(qs)\n",
1591
+ " print('from model')\n",
1592
+ " print(ma)\n",
1593
+ " print()\n",
1594
+ " print('expected answer')\n",
1595
+ " print(a)\n"
1596
  ]
1597
  }
1598
  ],