matt-tries-dl
committed on
Commit
•
b444d89
1
Parent(s):
357d6d7
update
Browse files- alpaca-lora +1 -0
- llama_test.ipynb +209 -29
- requirements.txt +2 -1
alpaca-lora
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 8bb8579e403dc78e37fe81ffbb253c413007323f
|
llama_test.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [
|
8 |
{
|
@@ -11,7 +11,7 @@
|
|
11 |
"True"
|
12 |
]
|
13 |
},
|
14 |
-
"execution_count":
|
15 |
"metadata": {},
|
16 |
"output_type": "execute_result"
|
17 |
}
|
@@ -32,7 +32,7 @@
|
|
32 |
},
|
33 |
{
|
34 |
"cell_type": "code",
|
35 |
-
"execution_count":
|
36 |
"metadata": {},
|
37 |
"outputs": [
|
38 |
{
|
@@ -47,7 +47,7 @@
|
|
47 |
{
|
48 |
"data": {
|
49 |
"application/vnd.jupyter.widget-view+json": {
|
50 |
-
"model_id": "
|
51 |
"version_major": 2,
|
52 |
"version_minor": 0
|
53 |
},
|
@@ -83,7 +83,7 @@
|
|
83 |
},
|
84 |
{
|
85 |
"cell_type": "code",
|
86 |
-
"execution_count":
|
87 |
"metadata": {},
|
88 |
"outputs": [
|
89 |
{
|
@@ -132,7 +132,7 @@
|
|
132 |
},
|
133 |
{
|
134 |
"cell_type": "code",
|
135 |
-
"execution_count":
|
136 |
"metadata": {},
|
137 |
"outputs": [
|
138 |
{
|
@@ -168,7 +168,7 @@
|
|
168 |
},
|
169 |
{
|
170 |
"cell_type": "code",
|
171 |
-
"execution_count":
|
172 |
"metadata": {},
|
173 |
"outputs": [
|
174 |
{
|
@@ -232,7 +232,7 @@
|
|
232 |
},
|
233 |
{
|
234 |
"cell_type": "code",
|
235 |
-
"execution_count":
|
236 |
"metadata": {},
|
237 |
"outputs": [
|
238 |
{
|
@@ -240,25 +240,30 @@
|
|
240 |
"output_type": "stream",
|
241 |
"text": [
|
242 |
"\n",
|
243 |
-
"
|
244 |
-
"
|
245 |
-
"
|
|
|
246 |
"\n",
|
247 |
-
"
|
248 |
-
"
|
249 |
-
"
|
|
|
250 |
"\n",
|
251 |
-
"
|
252 |
-
"
|
253 |
-
"
|
|
|
254 |
"\n",
|
255 |
-
"
|
256 |
-
"
|
257 |
-
"
|
|
|
258 |
"\n",
|
259 |
-
"
|
260 |
-
"
|
261 |
-
"
|
|
|
262 |
]
|
263 |
}
|
264 |
],
|
@@ -303,11 +308,11 @@
|
|
303 |
"tbl_types = {}\n",
|
304 |
"tbl_str = {}\n",
|
305 |
"\n",
|
306 |
-
"prefix = '
|
307 |
"\n",
|
308 |
"def tbl_def_to_string(id, header, types):\n",
|
309 |
" ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n",
|
310 |
-
" s = f'
|
311 |
" return s\n",
|
312 |
"\n",
|
313 |
"with open('data/train.tables.jsonl') as f:\n",
|
@@ -330,26 +335,201 @@
|
|
330 |
" id = js['table_id']\n",
|
331 |
" s = tbl_str[id]\n",
|
332 |
" qst = js['question']\n",
|
333 |
-
" nl = prefix +
|
334 |
" nl_q.append(nl)\n",
|
335 |
"\n",
|
336 |
" sql = js['sql']\n",
|
337 |
" a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
|
338 |
-
" a = '
|
339 |
" sql_a.append(a)\n",
|
340 |
"\n",
|
341 |
"\n",
|
342 |
"M = len(nl_q)\n",
|
343 |
"\n",
|
|
|
344 |
"\n",
|
345 |
"for i in range(5):\n",
|
346 |
" j = random.randint(0,M-1)\n",
|
347 |
" print()\n",
|
348 |
-
" print(
|
349 |
-
" print(sql_a[j]) \n",
|
350 |
" \n",
|
351 |
" "
|
352 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
}
|
354 |
],
|
355 |
"metadata": {
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
"metadata": {},
|
7 |
"outputs": [
|
8 |
{
|
|
|
11 |
"True"
|
12 |
]
|
13 |
},
|
14 |
+
"execution_count": 1,
|
15 |
"metadata": {},
|
16 |
"output_type": "execute_result"
|
17 |
}
|
|
|
32 |
},
|
33 |
{
|
34 |
"cell_type": "code",
|
35 |
+
"execution_count": 2,
|
36 |
"metadata": {},
|
37 |
"outputs": [
|
38 |
{
|
|
|
47 |
{
|
48 |
"data": {
|
49 |
"application/vnd.jupyter.widget-view+json": {
|
50 |
+
"model_id": "3ab80e2a1c0744e0af747ba63429a2af",
|
51 |
"version_major": 2,
|
52 |
"version_minor": 0
|
53 |
},
|
|
|
83 |
},
|
84 |
{
|
85 |
"cell_type": "code",
|
86 |
+
"execution_count": 3,
|
87 |
"metadata": {},
|
88 |
"outputs": [
|
89 |
{
|
|
|
132 |
},
|
133 |
{
|
134 |
"cell_type": "code",
|
135 |
+
"execution_count": 13,
|
136 |
"metadata": {},
|
137 |
"outputs": [
|
138 |
{
|
|
|
168 |
},
|
169 |
{
|
170 |
"cell_type": "code",
|
171 |
+
"execution_count": 4,
|
172 |
"metadata": {},
|
173 |
"outputs": [
|
174 |
{
|
|
|
232 |
},
|
233 |
{
|
234 |
"cell_type": "code",
|
235 |
+
"execution_count": 5,
|
236 |
"metadata": {},
|
237 |
"outputs": [
|
238 |
{
|
|
|
240 |
"output_type": "stream",
|
241 |
"text": [
|
242 |
"\n",
|
243 |
+
"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
|
244 |
+
"### Question: What is the Displacement of the Iveco F1CE3481E Engine?\n",
|
245 |
+
"### Input: Table 2-1415821-6 has columns Model (text),Engine (text),Displacement (text),Valvetrain (text),Fuel system (text),Max. power at rpm (text),Max. torque at rpm (text). \n",
|
246 |
+
"### Answer: SELECT Displacement FROM 2-1415821-6 WHERE Engine = 'iveco f1ce3481e'\n",
|
247 |
"\n",
|
248 |
+
"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
|
249 |
+
"### Question: What is the record of team utah?\n",
|
250 |
+
"### Input: Table 2-17355628-9 has columns Game (real),Date (text),Team (text),Score (text),High points (text),High rebounds (text),High assists (text),Location Attendance (text),Record (text). \n",
|
251 |
+
"### Answer: SELECT Record FROM 2-17355628-9 WHERE Team = 'utah'\n",
|
252 |
"\n",
|
253 |
+
"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
|
254 |
+
"### Question: What is the home of the team with a 16-8 record?\n",
|
255 |
+
"### Input: Table 2-16188254-4 has columns Date (text),Visitor (text),Score (text),Home (text),Leading scorer (text),Attendance (text),Record (text). \n",
|
256 |
+
"### Answer: SELECT Home FROM 2-16188254-4 WHERE Record = '16-8'\n",
|
257 |
"\n",
|
258 |
+
"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
|
259 |
+
"### Question: What week did the Galaxy play the Amsterdam Admirals?\n",
|
260 |
+
"### Input: Table 1-24814477-2 has columns Week (real),Date (text),Kickoff (text),Opponent (text),Final score (text),Team record (text),Game site (text),Attendance (real). \n",
|
261 |
+
"### Answer: SELECT Week FROM 1-24814477-2 WHERE Opponent = 'Amsterdam Admirals'\n",
|
262 |
"\n",
|
263 |
+
"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\n",
|
264 |
+
"### Question: How many caps did Mitchell Duke have overall?\n",
|
265 |
+
"### Input: Table 2-1257177-1 has columns Player (text),Country (text),Caps (real),Goals (text),Years Active (text). \n",
|
266 |
+
"### Answer: SELECT COUNT Caps FROM 2-1257177-1 WHERE Player = 'mitchell duke'\n"
|
267 |
]
|
268 |
}
|
269 |
],
|
|
|
308 |
"tbl_types = {}\n",
|
309 |
"tbl_str = {}\n",
|
310 |
"\n",
|
311 |
+
"prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n",
|
312 |
"\n",
|
313 |
"def tbl_def_to_string(id, header, types):\n",
|
314 |
" ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n",
|
315 |
+
" s = f'\\n### Input: Table {id} has columns ' + ','.join(ht) + '. '\n",
|
316 |
" return s\n",
|
317 |
"\n",
|
318 |
"with open('data/train.tables.jsonl') as f:\n",
|
|
|
335 |
" id = js['table_id']\n",
|
336 |
" s = tbl_str[id]\n",
|
337 |
" qst = js['question']\n",
|
338 |
+
" nl = prefix + \"\\n### Question: \" + qst + s\n",
|
339 |
" nl_q.append(nl)\n",
|
340 |
"\n",
|
341 |
" sql = js['sql']\n",
|
342 |
" a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
|
343 |
+
" a = '\\n### Answer: ' + a\n",
|
344 |
" sql_a.append(a)\n",
|
345 |
"\n",
|
346 |
"\n",
|
347 |
"M = len(nl_q)\n",
|
348 |
"\n",
|
349 |
+
"data_txt = [nl_q[i] + sql_a[i] for i in range(len(nl_q))]\n",
|
350 |
"\n",
|
351 |
"for i in range(5):\n",
|
352 |
" j = random.randint(0,M-1)\n",
|
353 |
" print()\n",
|
354 |
+
" print(data_txt[j]) \n",
|
|
|
355 |
" \n",
|
356 |
" "
|
357 |
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"attachments": {},
|
361 |
+
"cell_type": "markdown",
|
362 |
+
"metadata": {},
|
363 |
+
"source": [
|
364 |
+
"Set up the details for the model."
|
365 |
+
]
|
366 |
+
},
|
367 |
+
{
|
368 |
+
"cell_type": "code",
|
369 |
+
"execution_count": 26,
|
370 |
+
"metadata": {},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"data": {
|
374 |
+
"application/vnd.jupyter.widget-view+json": {
|
375 |
+
"model_id": "4f44918087484dd58b958a64cabdecb6",
|
376 |
+
"version_major": 2,
|
377 |
+
"version_minor": 0
|
378 |
+
},
|
379 |
+
"text/plain": [
|
380 |
+
"Map: 0%| | 0/56355 [00:00<?, ? examples/s]"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
"metadata": {},
|
384 |
+
"output_type": "display_data"
|
385 |
+
}
|
386 |
+
],
|
387 |
+
"source": [
|
388 |
+
"from peft import LoraConfig, get_peft_model\n",
|
389 |
+
"import transformers\n",
|
390 |
+
"import datasets\n",
|
391 |
+
"\n",
|
392 |
+
"LORA_R = 4\n",
|
393 |
+
"LORA_ALPHA = 16\n",
|
394 |
+
"LORA_DROPOUT = .1\n",
|
395 |
+
"CUTOFF_LEN = 256\n",
|
396 |
+
"BATCH = 128\n",
|
397 |
+
"MICRO_BATCH = 4\n",
|
398 |
+
"N_GAS = BATCH//MICRO_BATCH\n",
|
399 |
+
"EPOCHS = 1\n",
|
400 |
+
"LR = 1e-5\n",
|
401 |
+
"\n",
|
402 |
+
"lora_cfg = LoraConfig(\n",
|
403 |
+
" r = LORA_R,\n",
|
404 |
+
" lora_alpha=LORA_ALPHA,\n",
|
405 |
+
" lora_dropout=LORA_DROPOUT,\n",
|
406 |
+
" task_type='CAUSAL_LM',\n",
|
407 |
+
" target_modules=['q_proj','v_proj']\n",
|
408 |
+
")\n",
|
409 |
+
"\n",
|
410 |
+
"modad = get_peft_model(model,lora_cfg)\n",
|
411 |
+
"\n",
|
412 |
+
"tokenizer.pad_token_id = 0\n",
|
413 |
+
"\n",
|
414 |
+
"d = {'prompt': data_txt}\n",
|
415 |
+
"\n",
|
416 |
+
"data = datasets.Dataset.from_dict(d)\n",
|
417 |
+
"data = data.map(lambda x:\n",
|
418 |
+
" tokenizer(\n",
|
419 |
+
" x['prompt'],\n",
|
420 |
+
" truncation=True,\n",
|
421 |
+
" max_length=CUTOFF_LEN,\n",
|
422 |
+
" padding=\"max_length\"\n",
|
423 |
+
" ))\n",
|
424 |
+
"\n",
|
425 |
+
"#data.remove_columns('prompt')\n",
|
426 |
+
"\n",
|
427 |
+
"targs = transformers.TrainingArguments(\n",
|
428 |
+
" per_device_train_batch_size=MICRO_BATCH,\n",
|
429 |
+
" gradient_accumulation_steps=N_GAS,\n",
|
430 |
+
" warmup_steps=0,\n",
|
431 |
+
" num_train_epochs=EPOCHS,\n",
|
432 |
+
" learning_rate=LR,\n",
|
433 |
+
" fp16=True,\n",
|
434 |
+
" logging_steps=1,\n",
|
435 |
+
" output_dir='sqllama-out',\n",
|
436 |
+
" save_total_limit=3,\n",
|
437 |
+
" remove_unused_columns=False\n",
|
438 |
+
")\n",
|
439 |
+
"\n",
|
440 |
+
"\n",
|
441 |
+
"modad.config.use_cache = False"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"attachments": {},
|
446 |
+
"cell_type": "markdown",
|
447 |
+
"metadata": {},
|
448 |
+
"source": [
|
449 |
+
"ignore - just trying to figure out huggingface datasets"
|
450 |
+
]
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"cell_type": "code",
|
454 |
+
"execution_count": 27,
|
455 |
+
"metadata": {},
|
456 |
+
"outputs": [
|
457 |
+
{
|
458 |
+
"name": "stdout",
|
459 |
+
"output_type": "stream",
|
460 |
+
"text": [
|
461 |
+
"Dataset({\n",
|
462 |
+
" features: ['prompt', 'input_ids', 'attention_mask'],\n",
|
463 |
+
" num_rows: 56355\n",
|
464 |
+
"})\n",
|
465 |
+
"{'prompt': \"Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.\\n### Question: Tell me what the notes are for South Australia \\n### Input: Table 1-1000181-1 has columns State/territory (text),Text/background colour (text),Format (text),Current slogan (text),Current series (text),Notes (text). \\n### Answer: SELECT Notes FROM 1-1000181-1 WHERE Current slogan = 'SOUTH AUSTRALIA'\", 'input_ids': [0, 13866, 338, 263, 1139, 393, 16612, 263, 848, 2009, 29892, 3300, 2859, 411, 385, 1881, 393, 16612, 263, 3758, 1591, 29889, 29871, 14350, 263, 3758, 2346, 393, 5663, 17180, 278, 848, 29889, 13, 2277, 29937, 894, 29901, 24948, 592, 825, 278, 11486, 526, 363, 4275, 8314, 29871, 13, 2277, 29937, 10567, 29901, 6137, 29871, 29896, 29899, 29896, 29900, 29900, 29900, 29896, 29947, 29896, 29899, 29896, 756, 4341, 4306, 29914, 357, 768, 706, 313, 726, 511, 1626, 29914, 7042, 12384, 313, 726, 511, 5809, 313, 726, 511, 7583, 269, 1188, 273, 313, 726, 511, 7583, 3652, 313, 726, 511, 3664, 267, 313, 726, 467, 259, 13, 2277, 29937, 673, 29901, 5097, 29871, 8695, 3895, 29871, 29896, 29899, 29896, 29900, 29900, 29900, 29896, 29947, 29896, 29899, 29896, 5754, 9626, 269, 1188, 273, 353, 525, 6156, 2692, 29950, 319, 29965, 10810, 1964, 10764, 29915, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
|
466 |
+
]
|
467 |
+
}
|
468 |
+
],
|
469 |
+
"source": [
|
470 |
+
"print(data)\n",
|
471 |
+
"print(data[0])\n",
|
472 |
+
"\n",
|
473 |
+
"#from datasets import load_dataset\n",
|
474 |
+
"\n",
|
475 |
+
"\n",
|
476 |
+
"#!git clone https://github.com/tloen/alpaca-lora.git\n",
|
477 |
+
"#dalp = load_dataset(\"json\", data_files=\"alpaca-lora/alpaca_data.json\")\n",
|
478 |
+
"#print(dalp)\n",
|
479 |
+
"\n",
|
480 |
+
"#dalp = dalp.map(lambda x : {'blah':'blah'})\n",
|
481 |
+
"#print(dalp)"
|
482 |
+
]
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"cell_type": "code",
|
486 |
+
"execution_count": 25,
|
487 |
+
"metadata": {},
|
488 |
+
"outputs": [
|
489 |
+
{
|
490 |
+
"name": "stderr",
|
491 |
+
"output_type": "stream",
|
492 |
+
"text": [
|
493 |
+
"/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
494 |
+
" FutureWarning,\n"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"ename": "ValueError",
|
499 |
+
"evalue": "Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`prompt` in this case) have excessive nesting (inputs type `list` where type `int` is expected).",
|
500 |
+
"output_type": "error",
|
501 |
+
"traceback": [
|
502 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
503 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
504 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mconvert_to_tensors\u001b[0;34m(self, tensor_type, prepend_batch_axis)\u001b[0m\n\u001b[1;32m 716\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 717\u001b[0;31m \u001b[0mtensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mas_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 718\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
505 |
+
"\u001b[0;31mValueError\u001b[0m: too many dimensions 'str'",
|
506 |
+
"\nThe above exception was the direct cause of the following exception:\n",
|
507 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
508 |
+
"\u001b[0;32m/var/tmp/ipykernel_2309/3549391384.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdata_collator\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransformers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataCollatorForLanguageModeling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmlm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m )\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'sqllama-out'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
509 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1664\u001b[0m \u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresume_from_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1665\u001b[0m \u001b[0mtrial\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrial\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1666\u001b[0;31m \u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_keys_for_eval\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1667\u001b[0m )\n\u001b[1;32m 1668\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
510 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36m_inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1897\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1898\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1899\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1900\u001b[0m \u001b[0mtotal_batched_samples\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1901\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrng_to_sync\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
511 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;31m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 628\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 629\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
512 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 670\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 671\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 672\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
513 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
514 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/data/data_collator.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, features, return_tensors)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtf_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mreturn_tensors\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtorch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mreturn_tensors\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"np\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
515 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/data/data_collator.py\u001b[0m in \u001b[0;36mtorch_call\u001b[0;34m(self, examples)\u001b[0m\n\u001b[1;32m 727\u001b[0m \u001b[0;31m# Handle dict or lists with proper padding and conversion to tensor.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 729\u001b[0;31m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexamples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpad_to_multiple_of\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad_to_multiple_of\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 730\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 731\u001b[0m batch = {\n",
|
516 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mpad\u001b[0;34m(self, encoded_inputs, padding, max_length, pad_to_multiple_of, return_attention_mask, return_tensors, verbose)\u001b[0m\n\u001b[1;32m 3033\u001b[0m \u001b[0mbatch_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3034\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3035\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mBatchEncoding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3036\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3037\u001b[0m def create_token_type_ids_from_sequences(\n",
|
517 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, encoding, tensor_type, prepend_batch_axis, n_sequences)\u001b[0m\n\u001b[1;32m 208\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_n_sequences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn_sequences\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_to_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtensor_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprepend_batch_axis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprepend_batch_axis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 211\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
518 |
+
"\u001b[0;32m~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36mconvert_to_tensors\u001b[0;34m(self, tensor_type, prepend_batch_axis)\u001b[0m\n\u001b[1;32m 736\u001b[0m \u001b[0;34mf\" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;34m\" expected).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m ) from e\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 740\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
519 |
+
"\u001b[0;31mValueError\u001b[0m: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`prompt` in this case) have excessive nesting (inputs type `list` where type `int` is expected)."
|
520 |
+
]
|
521 |
+
}
|
522 |
+
],
|
523 |
+
"source": [
|
524 |
+
"trainer = transformers.Trainer(\n",
|
525 |
+
" model = modad,\n",
|
526 |
+
" train_dataset = data,\n",
|
527 |
+
" args = targs,\n",
|
528 |
+
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
|
529 |
+
")\n",
|
530 |
+
"trainer.train(resume_from_checkpoint=False)\n",
|
531 |
+
"model.save_pretrained('sqllama-out')"
|
532 |
+
]
|
533 |
}
|
534 |
],
|
535 |
"metadata": {
|
requirements.txt
CHANGED
@@ -5,8 +5,9 @@ torch
|
|
5 |
sentencepiece
|
6 |
transformers
|
7 |
accelerate
|
8 |
-
bitsandbytes
|
9 |
peft
|
|
|
10 |
tqdm
|
11 |
records
|
12 |
babel
|
|
|
5 |
sentencepiece
|
6 |
transformers
|
7 |
accelerate
|
8 |
+
bitsandbytes==0.37.2
|
9 |
peft
|
10 |
+
datasets
|
11 |
tqdm
|
12 |
records
|
13 |
babel
|