{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n", "PeftConfig(peft_type='LORA', base_model_name_or_path='decapoda-research/llama-7b-hf', task_type='CASUAL_LM', inference_mode=True)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4581e257066e463d9638f5d3a407a70e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = s + '\\nQ: ' + qst + '\\nA: '\n", " q_s.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = a + \"\\nEND\\n\"\n", " a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "92\n", " 0\n", "count 56355.000000\n", "mean 101.219519\n", "std 21.740325\n", "min 63.000000\n", "25% 87.500000\n", "50% 97.000000\n", "75% 109.000000\n", "max 461.000000\n", "32084\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "z = zip(q_s,lens)\n", "q_red = [a for a,b in z if b < 100]\n", "z = zip(a_s,lens)\n", "a_red = [a for a,b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3215cace7aef45b1b040a12f11509a7d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32084 [00:00\n", " \n", " \n", " [250/250 3:05:44, Epoch 0/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
11.121900
21.103000
31.128600
41.098400
51.094800
61.135100
71.114100
81.125100
91.122000
101.138500
111.114100
121.124700
131.133100
141.118900
151.121700
161.131600
171.137700
181.147000
191.113200
201.158000
211.114100
221.117700
231.151200
241.154200
251.153600
261.119700
271.138500
281.148200
291.103300
301.117700
311.116300
321.134400
331.125200
341.137900
351.150100
361.126200
371.129100
381.093200
391.153100
401.108000
411.137300
421.101600
431.140600
441.159900
451.112600
461.101200
471.088000
481.135300
491.118100
501.140300
511.104000
521.122900
531.162200
541.108500
551.121900
561.092100
571.109500
581.139400
591.120800
601.132200
611.138700
621.128700
631.122500
641.145800
651.135000
661.107900
671.120700
681.128000
691.107600
701.155700
711.142400
721.118900
731.129900
741.134400
751.105500
761.104100
771.100900
781.148200
791.116100
801.121700
811.154100
821.118900
831.109600
841.109300
851.147900
861.094300
871.130000
881.095100
891.145900
901.131600
911.114200
921.126600
931.100300
941.140900
951.132800
961.105900
971.106200
981.097400
991.114500
1001.113700
1011.093300
1021.121900
1031.133600
1041.131500
1051.136800
1061.130800
1071.102100
1081.128300
1091.163500
1101.144200
1111.125600
1121.119700
1131.111100
1141.122400
1151.142500
1161.124500
1171.117700
1181.130500
1191.118500
1201.097200
1211.123600
1221.135700
1231.153400
1241.088200
1251.123600
1261.143000
1271.121800
1281.091200
1291.116700
1301.124400
1311.139100
1321.119400
1331.115000
1341.133600
1351.100900
1361.095100
1371.142600
1381.097300
1391.113100
1401.150800
1411.149600
1421.106700
1431.086100
1441.134200
1451.096400
1461.099200
1471.168300
1481.105900
1491.119700
1501.100200
1511.089600
1521.128200
1531.148300
1541.119800
1551.102700
1561.107800
1571.113100
1581.156100
1591.091500
1601.118000
1611.145600
1621.115400
1631.121900
1641.130100
1651.123400
1661.090900
1671.144400
1681.125100
1691.110700
1701.134300
1711.092600
1721.123000
1731.080100
1741.104100
1751.105800
1761.156000
1771.104000
1781.118500
1791.123100
1801.117000
1811.122100
1821.141200
1831.135600
1841.093600
1851.156300
1861.095600
1871.128900
1881.101200
1891.149900
1901.112300
1911.117600
1921.090600
1931.097700
1941.084700
1951.128900
1961.126400
1971.113000
1981.107500
1991.160100
2001.125800
2011.125300
2021.127200
2031.114200
2041.114300
2051.119200
2061.114500
2071.086100
2081.096200
2091.115800
2101.094500
2111.106400
2121.121400
2131.137600
2141.107000
2151.095700
2161.083000
2171.088700
2181.133700
2191.115500
2201.152900
2211.100100
2221.112500
2231.119200
2241.122600
2251.100100
2261.082500
2271.094800
2281.123600
2291.124700
2301.148800
2311.109600
2321.096100
2331.123000
2341.102200
2351.113200
2361.150700
2371.131900
2381.107200
2391.137600
2401.094800
2411.068000
2421.122100
2431.153700
2441.045100
2451.131400
2461.134600
2471.105300
2481.108800
2491.080800
2501.119200

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer = transformers.Trainer(\n", " model = model,\n", " train_dataset = data,\n", " args = targs,\n", " #data_collator=transformers.DefaultDataCollator\n", " data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n", ")\n", "trainer.train(resume_from_checkpoint=False)\n", "model.save_pretrained('sqllama-out3')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "from model\n", "table: 2-10806592-14\n", "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "Q: What was the away score when the home team was Melbourne?\n", "A: SELECT Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "FROM 2-1080\n", "\n", "expected answer\n", "SELECT Away team score FROM 2-10806592-14 WHERE Home team = 'melbourne'\n", "END\n", "\n", "from model\n", "table: 2-17978030-6\n", "columns: Date,Time,Score,Set 1,Set 2,Set 3,Total\n", "Q: What is the score when the set 3 is 26–28?\n", "A: SELECT Score FROM 2-17978030-6 WHERE Set 3 = 26–28\n", "END\n", "\\end{code}\n", "\n", "\n", "expected answer\n", "SELECT Score FROM 2-17978030-6 WHERE Set 3 = '26–28'\n", "END\n", "\n", "from model\n", "table: 2-12487101-6\n", "columns: Position,Number,Name,Class,Injury (Status),Last Update\n", "Q: for the position of rb what is the name?\n", "A: SELECT Name FROM 2-12487101-6 WHERE Position = 'rb'\n", "END\n", "\\end{code}\n", "\n", "Comment: I'm not sure what you mean by \"the name of\n", "\n", "expected answer\n", "SELECT Name FROM 2-12487101-6 WHERE Position = 'rb'\n", "END\n", "\n", "from model\n", "table: 2-16780011-2\n", "columns: Week,Date,Opponent,Result,TV Time,Attendance\n", "Q: Which week had an attendance of 55,158?\n", "A: SELECT Attendance FROM 2-16780011-2 WHERE Week = '2019-01-06'\n", "END\n", "\\end{code}\n", "\n", "\n", "\n", "expected answer\n", "SELECT SUM Week FROM 2-16780011-2 WHERE Attendance = '55,158'\n", "END\n", "\n", "from model\n", "table: 1-1671401-1\n", "columns: Year,Starts,Wins,Top 5,Top 10,Poles,Avg. Start,Avg. Finish,Winnings,Position,Team(s)\n", "Q: Name the least top 5\n", "A: SELECT MIN Top 5 FROM 1-1671401-1\n", "END\n", "\\end{code}\n", "\n", "\n", "expected answer\n", "SELECT MIN Top 5 FROM 1-1671401-1\n", "END\n", "\n", "from model\n", "table: 2-17540875-4\n", "columns: Tie no,Home team,Score,Away team,Date\n", "Q: What team was the home team when Manchester City was the away team?\n", "A: SELECT Tie no,Home team FROM 2-17540875-4 WHERE Away team = 'Manchester City'\n", "END\n", "\n", "Comment: @user378326\n", "\n", "expected answer\n", "SELECT Home team FROM 2-17540875-4 WHERE Away team = 'manchester city'\n", "END\n", "\n", "from model\n", "table: 2-15524351-11\n", "columns: Date,Opponent,Location,Result,Attendance\n", "Q: What team was the opponent when the result was l 0-27?\n", "A: SELECT Opponent FROM 2-15524351-11 WHERE Result = 'l 0-27'\n", "END\n", "\\end{code}\n", "\n", "Comment\n", "\n", "expected answer\n", "SELECT Opponent FROM 2-15524351-11 WHERE Result = 'l 0-27'\n", "END\n", "\n", "from model\n", "table: 1-12962773-1\n", "columns: No,Player,Height,Position,Year born,Current Club\n", "Q: what's current club with height being 2.09\n", "A: SELECT Current Club\n", "FROM 1-12962773-1\n", "WHERE Height = 2.09\n", "END\n", "\\end{code}\n", "\n", "Comment: I'm\n", "\n", "expected answer\n", "SELECT Current Club FROM 1-12962773-1 WHERE Height = '2.09'\n", "END\n", "\n", "from model\n", "table: 1-12441518-1\n", "columns: Character,Portrayed by,Main cast seasons,Recurring cast seasons,# of episodes\n", "Q: How many people named Nick Lucas are on the show?\n", "A: SELECT COUNT(Character) FROM 1-12441518-1 WHERE Character = 'Nick Lucas'\n", "END\n", "\n", "Comment: @user3783264,\n", "\n", "expected answer\n", "SELECT COUNT Portrayed by FROM 1-12441518-1 WHERE Character = 'Nick Lucas'\n", "END\n", "\n", "from model\n", "table: 1-17176509-4\n", "columns: Position,Driver / Passenger,Equipment,Bike No,Points\n", "Q: Name the driver/passenger for 30\n", "A: SELECT Driver / Passenger\n", "FROM 1-17176509-4\n", "WHERE 30 = 'Driver / Passenger'\n", "END\n", "\n", "Comment: I'm not sure what you mean by\n", "\n", "expected answer\n", "SELECT COUNT Driver / Passenger FROM 1-17176509-4 WHERE Position = 30\n", "END\n", "\n" ] } ], "source": [ "def get_query(q):\n", " \n", " toks = tokenizer(q , return_tensors='pt')\n", " ctoks = toks.input_ids.to('cuda')\n", " gen = model.generate(ctoks, max_length=100)\n", " return tokenizer.decode(gen[0])\n", "\n", "M = len(q_red)\n", "\n", "for _ in range(10):\n", " j = random.randint(0,M-1)\n", " qs = q_red[j]\n", " a = a_red[j]\n", "\n", " ma = get_query(qs)\n", "\n", " #print(qs)\n", " print('from model')\n", " print(ma)\n", " print()\n", " print('expected answer')\n", " print(a)\n" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90" } } }, "nbformat": 4, "nbformat_minor": 2 }