matt-tries-dl commited on
Commit
7dd7ab4
1 Parent(s): 4f1cd24
llama_test.ipynb CHANGED
@@ -2307,6 +2307,63 @@
2307
  "trainer.train(resume_from_checkpoint=False)\n",
2308
  "model.save_pretrained('sqllama-out')"
2309
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2310
  }
2311
  ],
2312
  "metadata": {
 
2307
  "trainer.train(resume_from_checkpoint=False)\n",
2308
  "model.save_pretrained('sqllama-out')"
2309
  ]
2310
+ },
2311
+ {
2312
+ "cell_type": "code",
2313
+ "execution_count": 7,
2314
+ "metadata": {},
2315
+ "outputs": [
2316
+ {
2317
+ "name": "stderr",
2318
+ "output_type": "stream",
2319
+ "text": [
2320
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1220: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
2321
+ " \"You have modified the pretrained model configuration to control generation. This is a\"\n",
2322
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
2323
+ " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
2324
+ ]
2325
+ },
2326
+ {
2327
+ "name": "stdout",
2328
+ "output_type": "stream",
2329
+ "text": [
2330
+ "Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
2331
+ "### Question: What county has a CERCLIS ID of scd037405362?\n",
2332
+ "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n",
2333
+ "### Answer: \n",
2334
+ "<unk>Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
2335
+ "### Question: What county has a CERCLIS ID of scd037405362?\n",
2336
+ "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n",
2337
+ "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID = 'scd037405362' \n",
2338
+ "### Question: What county has a CERCLIS ID of scd037405362?\n",
2339
+ "### Input: Table 2-11960788-1 has columns CERCLIS ID (text),Name (text),County (text),Proposed (text),Listed (text). \n",
2340
+ "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID\n",
2341
+ "\n",
2342
+ "### Answer: SELECT County FROM 2-11960788-1 WHERE CERCLIS ID = 'scd037405362'\n"
2343
+ ]
2344
+ }
2345
+ ],
2346
+ "source": [
2347
+ "def get_query(q):\n",
2348
+ " \n",
2349
+ " toks = tokenizer(q , return_tensors='pt')\n",
2350
+ " ctoks = toks.input_ids.to('cuda')\n",
2351
+ " gen = model.generate(ctoks, max_length=256)\n",
2352
+ " return tokenizer.decode(gen[0])\n",
2353
+ "\n",
2354
+ "M = len(nl_q)\n",
2355
+ "j = random.randint(0,M-1)\n",
2356
+ "qs = nl_q[j] + '\\n### Answer: '\n",
2357
+ "a = sql_a[j]\n",
2358
+ "\n",
2359
+ "ma = get_query(qs)\n",
2360
+ "\n",
2361
+ "#print(qs)\n",
2362
+ "print('from model')\n",
2363
+ "print(ma)\n",
2364
+ "print('expected answer')\n",
2365
+ "print(a)\n"
2366
+ ]
2367
  }
2368
  ],
2369
  "metadata": {
sqllama-out2/adapter_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model_name_or_path": "decapoda-research/llama-7b-hf",
3
+ "bias": "none",
4
+ "enable_lora": null,
5
+ "fan_in_fan_out": false,
6
+ "inference_mode": true,
7
+ "lora_alpha": 16,
8
+ "lora_dropout": 0.1,
9
+ "merge_weights": false,
10
+ "modules_to_save": null,
11
+ "peft_type": "LORA",
12
+ "r": 4,
13
+ "target_modules": [
14
+ "q_proj",
15
+ "v_proj"
16
+ ],
17
+ "task_type": "CASUAL_LM"
18
+ }
sqllama-out2/adapter_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee15525f45ab11e3e7ba334c0639b7263ea25ae0d42aa22f801022020ffc493
3
+ size 8434381
wikisql.ipynb ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 11,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "True"
12
+ ]
13
+ },
14
+ "execution_count": 11,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "import torch\n",
21
+ "import torch.nn as nn\n",
22
+ "torch.cuda.is_available()"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "name": "stderr",
32
+ "output_type": "stream",
33
+ "text": [
34
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n",
35
+ " warn(msg)\n",
36
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
37
+ "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n",
38
+ "The class this function is called from is 'LlamaTokenizer'.\n"
39
+ ]
40
+ },
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "\n",
46
+ "===================================BUG REPORT===================================\n",
47
+ "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
48
+ "================================================================================\n",
49
+ "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n",
50
+ "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
51
+ "CUDA SETUP: Detected CUDA version 113\n",
52
+ "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n"
53
+ ]
54
+ },
55
+ {
56
+ "data": {
57
+ "application/vnd.jupyter.widget-view+json": {
58
+ "model_id": "a9428ee09f334655b6b261d478cbd3d0",
59
+ "version_major": 2,
60
+ "version_minor": 0
61
+ },
62
+ "text/plain": [
63
+ "Loading checkpoint shards: 0%| | 0/33 [00:00<?, ?it/s]"
64
+ ]
65
+ },
66
+ "metadata": {},
67
+ "output_type": "display_data"
68
+ }
69
+ ],
70
+ "source": [
71
+ "from transformers import LlamaTokenizer, LlamaForCausalLM\n",
72
+ "from peft import prepare_model_for_int8_training\n",
73
+ "tokenizer = LlamaTokenizer.from_pretrained(\n",
74
+ " \"decapoda-research/llama-7b-hf\")\n",
75
+ " \n",
76
+ "tokenizer.pad_token_id = 0\n",
77
+ "tokenizer.padding_side = 'left'\n",
78
+ "\n",
79
+ "model = LlamaForCausalLM.from_pretrained(\n",
80
+ " \"decapoda-research/llama-7b-hf\",\n",
81
+ " load_in_8bit=True,\n",
82
+ " device_map=\"auto\",\n",
83
+ " torch_dtype=torch.float16\n",
84
+ ")\n",
85
+ "\n",
86
+ "model = prepare_model_for_int8_training(model)"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 3,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "name": "stdout",
96
+ "output_type": "stream",
97
+ "text": [
98
+ "\n",
99
+ "table: 2-13081928-2\n",
100
+ "columns: Country,Chart,Period,Peak position,Sales\n",
101
+ "Q: Name the period for Chart of g-music j-pop/k-pop chart\n",
102
+ "A: SELECT Period FROM 2-13081928-2 WHERE Chart = 'g-music j-pop/k-pop chart'\n",
103
+ "\n",
104
+ "table: 2-13612447-1\n",
105
+ "columns: Fraction,Ellipsis,Vinculum,Dots,Parentheses\n",
106
+ "Q: What is the dot value when the ellipsis is 0.012345679…?\n",
107
+ "A: SELECT Dots FROM 2-13612447-1 WHERE Ellipsis = '0.012345679…'\n",
108
+ "\n",
109
+ "table: 1-168274-1\n",
110
+ "columns: Company,ICB Sector,Ticker symbol,Index weighting (%) at 17 January 2013,Market cap. at April 2013 (€)\n",
111
+ "Q: Name the total number of index weighting % at 17 january 2013 for bouygues\n",
112
+ "A: SELECT COUNT Index weighting (%) at 17 January 2013 FROM 1-168274-1 WHERE Company = 'Bouygues'\n",
113
+ "\n",
114
+ "table: 2-15826191-2\n",
115
+ "columns: Rank,Nation,Gold,Silver,Bronze,Total\n",
116
+ "Q: What is the lowest gold when there are 0 bronze and the total is less than 2, and silver is less than 0?\n",
117
+ "A: SELECT MIN Gold FROM 2-15826191-2 WHERE Bronze = 0 AND Total < 2 AND Silver < 0\n",
118
+ "\n",
119
+ "table: 2-16387912-1\n",
120
+ "columns: Home team,Home team score,Away team,Away team score,Ground,Date,Time\n",
121
+ "Q: What is Ground, when Away Team is Sydney?\n",
122
+ "A: SELECT Ground FROM 2-16387912-1 WHERE Away team = 'sydney'\n"
123
+ ]
124
+ }
125
+ ],
126
+ "source": [
127
+ "import random\n",
128
+ "import json\n",
129
+ "\n",
130
+ "# defined by WikiSQL\n",
131
+ "\n",
132
+ "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n",
133
+ "cond_ops = ['=', '>', '<', 'OP']\n",
134
+ "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n",
135
+ "\n",
136
+ "def fix_repr(d,cols,types,tid):\n",
137
+ " sel_index=d['sel'] \n",
138
+ " agg_index=d['agg']\n",
139
+ " conditions=d['conds']\n",
140
+ " col = cols[sel_index]\n",
141
+ " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n",
142
+ " agg=agg_ops[agg_index],\n",
143
+ " sel=col,\n",
144
+ " tid=tid\n",
145
+ " )\n",
146
+ " if conditions:\n",
147
+ " cs = []\n",
148
+ " for i, o, v in conditions:\n",
149
+ " #print(i,cols)\n",
150
+ " nm = cols[i]\n",
151
+ " op = cond_ops[o]\n",
152
+ " \n",
153
+ " if types[i] in ['text']:\n",
154
+ " val = f\"\\'{v}\\'\"\n",
155
+ " else:\n",
156
+ " val = v\n",
157
+ " cs.append(f'{nm} {op} {val}')\n",
158
+ " #print(cs)\n",
159
+ "\n",
160
+ " rep += ' WHERE ' + ' AND '.join(cs)\n",
161
+ " \n",
162
+ " return rep\n",
163
+ "\n",
164
+ "tbl_cols = {}\n",
165
+ "tbl_types = {}\n",
166
+ "tbl_str = {}\n",
167
+ "\n",
168
+ "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n",
169
+ "\n",
170
+ "def tbl_def_to_string(id, header, types):\n",
171
+ " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n",
172
+ " return s\n",
173
+ "\n",
174
+ "with open('data/train.tables.jsonl') as f:\n",
175
+ " for line in f:\n",
176
+ " js = json.loads(line)\n",
177
+ " id = js['id']\n",
178
+ " hdr = js['header']\n",
179
+ " ts = js['types']\n",
180
+ " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n",
181
+ " tbl_cols[id] = hdr\n",
182
+ " tbl_types[id] = ts\n",
183
+ "\n",
184
+ "q_s = []\n",
185
+ "a_s = []\n",
186
+ "\n",
187
+ "with open('data/train.jsonl') as f:\n",
188
+ " for line in f:\n",
189
+ " js = json.loads(line)\n",
190
+ " id = js['table_id']\n",
191
+ " s = tbl_str[id]\n",
192
+ " qst = js['question']\n",
193
+ " nl = s + '\\nQ: ' + qst + '\\nA: '\n",
194
+ " q_s.append(nl)\n",
195
+ "\n",
196
+ " sql = js['sql']\n",
197
+ " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
198
+ " a = a + \"\\nEND\\n\"\n",
199
+ " a_s.append(a)\n",
200
+ "\n",
201
+ "M = len(q_s)\n",
202
+ "\n",
203
+ "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n",
204
+ "\n",
205
+ "for i in range(5):\n",
206
+ " j = random.randint(0,M-1)\n",
207
+ " print()\n",
208
+ " print(data_txt[j]) \n",
209
+ " \n",
210
+ " "
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": 4,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "toks = [tokenizer(s) for s in data_txt]\n"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 5,
225
+ "metadata": {},
226
+ "outputs": [
227
+ {
228
+ "name": "stdout",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "89\n",
232
+ " 0\n",
233
+ "count 56355.000000\n",
234
+ "mean 98.219519\n",
235
+ "std 21.740325\n",
236
+ "min 60.000000\n",
237
+ "25% 84.500000\n",
238
+ "50% 94.000000\n",
239
+ "75% 106.000000\n",
240
+ "max 458.000000\n",
241
+ "35608\n"
242
+ ]
243
+ }
244
+ ],
245
+ "source": [
246
+ "import numpy as np\n",
247
+ "import pandas as pd\n",
248
+ "\n",
249
+ "print(len(toks[0].input_ids))\n",
250
+ "lens = np.array([len(tok.input_ids) for tok in toks])\n",
251
+ "print(pd.DataFrame(lens).describe())\n",
252
+ "\n",
253
+ "z = zip(q_s,lens)\n",
254
+ "q_red = [a for a,b in z if b < 100]\n",
255
+ "z = zip(a_s,lens)\n",
256
+ "a_red = [a for a,b in z if b < 100]\n",
257
+ "\n",
258
+ "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n",
259
+ "print(len(data_red))\n",
260
+ "\n"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 7,
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "data": {
270
+ "application/vnd.jupyter.widget-view+json": {
271
+ "model_id": "d548eb2af20f435fa1af81e9045a2d0e",
272
+ "version_major": 2,
273
+ "version_minor": 0
274
+ },
275
+ "text/plain": [
276
+ "Map: 0%| | 0/1000 [00:00<?, ? examples/s]"
277
+ ]
278
+ },
279
+ "metadata": {},
280
+ "output_type": "display_data"
281
+ }
282
+ ],
283
+ "source": [
284
+ "import random, datasets\n",
285
+ "d = {'prompt': random.sample(data_red, 1000)}\n",
286
+ "\n",
287
+ "tokenizer.pad_token_id = tokenizer.eos_token\n",
288
+ "\n",
289
+ "data = datasets.Dataset.from_dict(d)\n",
290
+ "data = data.map(lambda x:\n",
291
+ " tokenizer(\n",
292
+ " x['prompt'],\n",
293
+ " truncation=True,\n",
294
+ " max_length=100,\n",
295
+ " padding=\"max_length\"\n",
296
+ " ))\n",
297
+ "\n",
298
+ "data = data.remove_columns('prompt')\n"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 8,
304
+ "metadata": {},
305
+ "outputs": [],
306
+ "source": [
307
+ "from peft import LoraConfig, get_peft_model\n",
308
+ "import transformers\n",
309
+ "import datasets\n",
310
+ "\n",
311
+ "LORA_R = 4\n",
312
+ "LORA_ALPHA = 16\n",
313
+ "LORA_DROPOUT = .1\n",
314
+ "CUTOFF_LEN = 256\n",
315
+ "BATCH = 128\n",
316
+ "MICRO_BATCH = 4\n",
317
+ "N_GAS = BATCH//MICRO_BATCH\n",
318
+ "EPOCHS = 1\n",
319
+ "LR = 1e-4\n",
320
+ "\n",
321
+ "lora_cfg = LoraConfig(\n",
322
+ " r = LORA_R,\n",
323
+ " lora_alpha=LORA_ALPHA,\n",
324
+ " lora_dropout=LORA_DROPOUT,\n",
325
+ " task_type='CASUAL_LM',\n",
326
+ " target_modules=['q_proj','v_proj']\n",
327
+ ")\n",
328
+ "\n",
329
+ "model = get_peft_model(model,lora_cfg)\n",
330
+ "\n",
331
+ "targs = transformers.TrainingArguments(\n",
332
+ " per_device_train_batch_size=MICRO_BATCH,\n",
333
+ " gradient_accumulation_steps=N_GAS,\n",
334
+ " warmup_steps=0,\n",
335
+ " num_train_epochs=EPOCHS,\n",
336
+ " learning_rate=LR,\n",
337
+ " fp16=True,\n",
338
+ " logging_steps=1,\n",
339
+ " output_dir='sqllama-out2',\n",
340
+ " save_total_limit=3,\n",
341
+ " remove_unused_columns=False\n",
342
+ ")\n"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": 9,
348
+ "metadata": {},
349
+ "outputs": [
350
+ {
351
+ "data": {
352
+ "text/html": [
353
+ "\n",
354
+ " <div>\n",
355
+ " \n",
356
+ " <progress value='7' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
357
+ " [7/7 05:33, Epoch 0/1]\n",
358
+ " </div>\n",
359
+ " <table border=\"1\" class=\"dataframe\">\n",
360
+ " <thead>\n",
361
+ " <tr style=\"text-align: left;\">\n",
362
+ " <th>Step</th>\n",
363
+ " <th>Training Loss</th>\n",
364
+ " </tr>\n",
365
+ " </thead>\n",
366
+ " <tbody>\n",
367
+ " <tr>\n",
368
+ " <td>1</td>\n",
369
+ " <td>2.710700</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <td>2</td>\n",
373
+ " <td>2.680400</td>\n",
374
+ " </tr>\n",
375
+ " <tr>\n",
376
+ " <td>3</td>\n",
377
+ " <td>2.684500</td>\n",
378
+ " </tr>\n",
379
+ " <tr>\n",
380
+ " <td>4</td>\n",
381
+ " <td>2.625600</td>\n",
382
+ " </tr>\n",
383
+ " <tr>\n",
384
+ " <td>5</td>\n",
385
+ " <td>2.609600</td>\n",
386
+ " </tr>\n",
387
+ " <tr>\n",
388
+ " <td>6</td>\n",
389
+ " <td>2.619100</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <td>7</td>\n",
393
+ " <td>2.603800</td>\n",
394
+ " </tr>\n",
395
+ " </tbody>\n",
396
+ "</table><p>"
397
+ ],
398
+ "text/plain": [
399
+ "<IPython.core.display.HTML object>"
400
+ ]
401
+ },
402
+ "metadata": {},
403
+ "output_type": "display_data"
404
+ }
405
+ ],
406
+ "source": [
407
+ "trainer = transformers.Trainer(\n",
408
+ " model = model,\n",
409
+ " train_dataset = data,\n",
410
+ " args = targs,\n",
411
+ " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
412
+ ")\n",
413
+ "trainer.train(resume_from_checkpoint=False)\n",
414
+ "model.save_pretrained('sqllama-out2')"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": 10,
420
+ "metadata": {},
421
+ "outputs": [
422
+ {
423
+ "name": "stderr",
424
+ "output_type": "stream",
425
+ "text": [
426
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1220: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
427
+ " \"You have modified the pretrained model configuration to control generation. This is a\"\n",
428
+ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
429
+ " warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n"
430
+ ]
431
+ },
432
+ {
433
+ "name": "stdout",
434
+ "output_type": "stream",
435
+ "text": [
436
+ "from model\n",
437
+ " ⁇ table: 1-25800134-1\n",
438
+ "columns: Series #,Season #,Title,Director,Writer(s),Airdate\n",
439
+ "Q: Who wrote the episode with series number 56?\n",
440
+ "A: 56-101, \"The Cage\", Gene Roddenberry\n",
441
+ "Q: Who wrote the episode with series number 56? (2)\n",
442
+ "A: 56-101,\n",
443
+ "expected answer SELECT Writer(s) FROM 1-25800134-1 WHERE Series # = 56\n"
444
+ ]
445
+ }
446
+ ],
447
+ "source": [
448
+ "def get_query(q):\n",
449
+ " \n",
450
+ " toks = tokenizer(q , return_tensors='pt')\n",
451
+ " ctoks = toks.input_ids.to('cuda')\n",
452
+ " gen = model.generate(ctoks, max_length=100)\n",
453
+ " return tokenizer.decode(gen[0])\n",
454
+ "\n",
455
+ "M = len(q_red)\n",
456
+ "j = random.randint(0,M-1)\n",
457
+ "qs = q_red[j]\n",
458
+ "a = a_red[j]\n",
459
+ "\n",
460
+ "ma = get_query(qs)\n",
461
+ "\n",
462
+ "#print(qs)\n",
463
+ "print('from model')\n",
464
+ "print(ma)\n",
465
+ "print\n",
466
+ "print('expected answer',a)\n"
467
+ ]
468
+ }
469
+ ],
470
+ "metadata": {
471
+ "kernelspec": {
472
+ "display_name": ".venv",
473
+ "language": "python",
474
+ "name": "python3"
475
+ },
476
+ "language_info": {
477
+ "codemirror_mode": {
478
+ "name": "ipython",
479
+ "version": 3
480
+ },
481
+ "file_extension": ".py",
482
+ "mimetype": "text/x-python",
483
+ "name": "python",
484
+ "nbconvert_exporter": "python",
485
+ "pygments_lexer": "ipython3",
486
+ "version": "3.7.3"
487
+ },
488
+ "orig_nbformat": 4,
489
+ "vscode": {
490
+ "interpreter": {
491
+ "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90"
492
+ }
493
+ }
494
+ },
495
+ "nbformat": 4,
496
+ "nbformat_minor": 2
497
+ }