import torch
import torch.nn as nn
# Sanity check before anything heavy: confirm a CUDA device is visible
# (the model-loading and generate() cells below assume 'cuda').
torch.cuda.is_available()
import random
import json

# Operator/keyword vocabulary fixed by the WikiSQL task definition.
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION',
        'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']


def fix_repr(d, cols, types, tid):
    """Render a WikiSQL logical form `d` as a SQL string for table `tid`.

    `cols`/`types` are the table's column names and types; condition values
    on text-typed columns are single-quoted, all other values are used as-is.
    """
    sql = 'SELECT {agg} {sel} FROM {tid}'.format(
        agg=agg_ops[d['agg']],
        sel=cols[d['sel']],
        tid=tid,
    )
    conds = d['conds']
    if conds:
        clauses = []
        for col_idx, op_idx, value in conds:
            rendered = f"'{value}'" if types[col_idx] in ['text'] else value
            clauses.append(f'{cols[col_idx]} {cond_ops[op_idx]} {rendered}')
        sql += ' WHERE ' + ' AND '.join(clauses)
    return sql


# Per-table lookup maps keyed by WikiSQL table id.
tbl_cols = {}
tbl_types = {}
tbl_str = {}

prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'


def tbl_def_to_string(id, header, types):
    """One-line schema header ('table: <id>' + newline + 'columns: a,b,c')."""
    return f'table: {id}\ncolumns: ' + ','.join(header)


with open('data/train.tables.jsonl') as f:
    for line in f:
        rec = json.loads(line)
        tid = rec['id']
        tbl_str[tid] = tbl_def_to_string(tid, rec['header'], rec['types'])
        tbl_cols[tid] = rec['header']
        tbl_types[tid] = rec['types']

q_s = []  # prompts: schema + question + 'A: '
a_s = []  # targets: rendered SQL + END sentinel

with open('data/train.jsonl') as f:
    for line in f:
        rec = json.loads(line)
        tid = rec['table_id']
        q_s.append(tbl_str[tid] + '\nQ: ' + rec['question'] + '\nA: ')
        sql = fix_repr(rec['sql'], tbl_cols[tid], tbl_types[tid], tid)
        a_s.append(sql + '\nEND\n')

M = len(q_s)
data_txt = [q + a for q, a in zip(q_s, a_s)]

# Eyeball a few random training examples.
for _ in range(5):
    j = random.randint(0, M - 1)
    print()
    print(data_txt[j])
# Tokenize every prompt+answer pair (`tokenizer` comes from the model cell).
toks = [tokenizer(s) for s in data_txt]

import numpy as np
import pandas as pd

# Inspect the token-length distribution of the corpus.
print(len(toks[0].input_ids))
lens = np.array([len(t.input_ids) for t in toks])
print(pd.DataFrame(lens).describe())

# Keep only examples whose full prompt+answer fits under 100 tokens.
q_red = [q for q, n in zip(q_s, lens) if n < 100]
a_red = [a for a, n in zip(a_s, lens) if n < 100]

data_red = [q + a for q, a in zip(q_red, a_red)]
print(len(data_red))
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
12.748800
22.723800
32.737600
42.707100
52.692800
62.720700
72.681400
82.736400
92.701800
102.711700
112.685800
122.684300
132.686300
142.698800
152.659300
162.688900
172.661800
182.677700
192.647100
202.679800
212.652000
222.628900
232.656100
242.669100
252.667800
262.636300
272.616800
282.630600
292.621000
302.602000
312.607900
322.635800
332.594600
342.604400
352.618900
362.563400
372.589200
382.552100
392.583600
402.554500
412.557400
422.536700
432.535000
442.557900
452.530100
462.527900
472.510100
482.539100
492.500100
502.536200
512.487100
522.521700
532.532600
542.494500
552.468900
562.468700
572.474300
582.480900
592.442800
602.472800
612.452900
622.452000
632.443100
642.446700
652.415100
662.376300
672.411500
682.403900
692.383800
702.427800
712.419400
722.371900
732.364400
742.360000
752.337600
762.332800
772.315700
782.344200
792.331700
802.303100
812.324700
822.285900
832.268000
842.260600
852.286100
862.233600
872.266200
882.217000
892.249300
902.239000
912.221900
922.223300
932.179500
942.204400
952.193200
962.163800
972.158200
982.127700
992.141400
1002.121400
1012.115500
1022.125200
1032.140100
1042.118400
1052.110400
1062.097300
1072.071400
1082.083400
1092.090200
1102.078200
1112.061100
1122.047500
1132.006100
1142.023800
1152.014000
1162.008800
1171.988800
1181.984900
1191.971000
1201.924100
1211.953100
1221.957800
1231.952500
1241.890400
1251.915900
1261.901100
1271.879900
1281.834100
1291.855900
1301.853800
1311.869200
1321.821400
1331.835100
1341.817700
1351.785800
1361.764000
1371.796800
1381.751100
1391.756500
1401.789900
1411.773100
1421.729200
1431.700200
1441.721200
1451.690600
1461.687700
1471.743500
1481.690000
1491.687200
1501.663000
1511.648600
1521.667100
1531.665600
1541.647000
1551.629500
1561.620800
1571.616400
1581.658500
1591.593900
1601.604300
1611.621200
1621.607900
1631.591100
1641.598100
1651.579700
1661.545500
1671.582100
1681.568300
1691.557900
1701.561300
1711.521800
1721.542500
1731.502300
1741.513900
1751.501500
1761.551200
1771.495600
1781.504000
1791.512500
1801.488200
1811.492200
1821.494300
1831.494800
1841.446100
1851.514700
1861.450900
1871.476900
1881.447100
1891.490800
1901.433200
1911.438100
1921.410500
1931.422600
1941.405500
1951.439400
1961.448100
1971.410200
1981.403800
1991.464400
2001.417700
2011.419500
2021.419400
2031.387700
2041.400400
2051.404700
2061.398400
2071.358000
2081.359600
2091.367700
2101.358600
2111.369200
2121.373700
2131.395100
2141.360800
2151.343900
2161.330300
2171.328800
2181.369900
2191.346300
2201.379700
2211.326000
2221.334600
2231.339100
2241.349200
2251.324800
2261.303600
2271.299900
2281.338800
2291.331800
2301.351400
2311.314200
2321.293600
2331.322100
2341.295800
2351.302500
2361.338900
2371.308900
2381.290100
2391.323300
2401.270500
2411.246300
2421.303900
2431.324800
2441.216000
2451.303500
2461.304900
2471.273300
2481.278300
2491.252000
2501.283400
2511.271600
2521.300300
2531.265800
2541.249200
2551.252600
2561.265500
2571.228600
2581.257300
2591.288900
2601.257200
2611.243700
2621.272100
2631.252000
2641.264900
2651.268800
2661.256000
2671.230200
2681.231700
2691.243400
2701.285200
2711.225500
2721.217900
2731.209200
2741.224200
2751.226400
2761.261500
2771.223900
2781.244000
2791.226600
2801.235000
2811.213400
2821.177600
2831.218100
2841.231900
2851.200900
2861.223400
2871.235100
2881.232500
2891.230100
2901.225900
2911.182700
2921.237100
2931.201000
2941.213000
2951.205500
2961.181900
2971.198300
2981.195200
2991.215000
3001.195500
3011.186100
3021.174900
3031.184400
3041.207100
3051.181100
3061.195300
3071.189000
3081.180200
3091.167200
3101.206700
3111.203600
3121.186600
3131.224100
3141.180000
3151.186600
3161.150700
3171.165700
3181.178100
3191.148300
3201.153600
3211.189200
3221.182100
3231.183800
3241.202900
3251.196600
3261.200800
3271.153100
3281.212400
3291.167300
3301.188300
3311.179300
3321.211400
3331.169900
3341.179300
3351.153300
3361.188900
3371.179200
3381.217300
3391.169700
3401.177700
3411.197300
3421.177800
3431.169700
3441.186800
3451.180000
3461.193400
3471.171900
3481.190000
3491.160900
3501.170800
3511.166900
3521.183200
3531.118200
3541.185900
3551.157800
3561.160200
3571.184200
3581.172100
3591.143800
3601.178000
3611.157900
3621.151700
3631.196600
3641.181800
3651.195600
3661.165000
3671.157300
3681.165200
3691.167700
3701.184900
3711.168400
3721.150500
3731.152900
3741.158900
3751.143900
3761.157200
3771.146800
3781.142600
3791.140600
3801.142400
3811.114100
3821.169700
3831.142500
3841.176000
3851.160600
3861.164700
3871.124000
3881.134500
3891.185500
3901.154300
3911.125500
3921.174400
3931.132800
3941.145200
3951.129800
3961.140600
3971.126000
3981.182800
3991.127800
4001.155000
4011.134600
4021.155900
4031.150400
4041.141700
4051.131500
4061.169600
4071.170500
4081.129100
4091.151700
4101.168200
4111.109100
4121.129700
4131.143900
4141.157300
4151.128900
4161.171500
4171.141600
4181.157700
4191.137000
4201.154000
4211.167300
4221.137400
4231.121500
4241.128500
4251.130300
4261.162100
4271.155100
4281.145300
4291.121000
4301.182200
4311.157000
4321.162300
4331.135200
4341.141300
4351.151700
4361.148000
4371.132500
4381.163000
4391.116300
4401.142000
4411.091700
4421.141500
4431.154900
4441.120400
4451.173700
4461.138300
4471.135600
4481.138800
4491.126800
4501.129400
4511.146300
4521.104200
4531.163500
4541.169300
4551.147100
4561.157100
4571.122100
4581.121900
4591.150500
4601.115700
4611.121100
4621.123400
4631.097500
4641.103800
4651.167700
4661.130000
4671.164500
4681.127200
4691.133800
4701.132700
4711.122800
4721.159500
4731.122900
4741.105000
4751.145700
4761.086400
4771.112600
4781.139300
4791.135000
4801.135200
4811.117500
4821.102300
4831.147700
4841.119200
4851.125800
4861.135400
4871.149500
4881.099400
4891.153900
4901.122700
4911.089400
4921.167200
4931.151300
4941.131400
4951.131400
4961.145200
4971.125700
4981.119300
4991.128600
5001.121000

# Fine-tune with the HF Trainer. `model`, `data`, `targs`, and `tokenizer`
# are defined in earlier cells.
# mlm=False: the collator builds causal-LM labels (a copy of input_ids)
# instead of masked-token targets.
collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
trainer = transformers.Trainer(
    model=model,
    args=targs,
    train_dataset=data,
    data_collator=collator,
)
trainer.train(resume_from_checkpoint=False)
model.save_pretrained('sqllama-out3')
Gradients will be None\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "from model\n", "table: 2-11561331-17\n", "columns: Name,Actual version,System,Platform,License\n", "Q: Which System's Name is Steem, and has a Freeware License?\n", "A: SELECT Name FROM 2-11561331-17 WHERE License = 'Freeware' AND System = 'Steem'\n", "END\n", "\\end{code}\n", "\n", "\n", "\n", "expected answer\n", "SELECT System FROM 2-11561331-17 WHERE License = 'freeware' AND Name = 'steem'\n", "END\n", "\n", "from model\n", "table: 1-18847736-2\n", "columns: Game,Date,Opponent,Result,Dolphins points,Opponents,Record,Attendance\n", "Q: What is the date when the opponent is the New England Patriots?\n", "A: SELECT Date FROM 1-18847736-2 WHERE Opponent = 'New England Patriots'\n", "END\n", "\\end\n", "\n", "expected answer\n", "SELECT Date FROM 1-18847736-2 WHERE Opponent = 'New England Patriots'\n", "END\n", "\n", "from model\n", "table: 1-12028543-3\n", "columns: Season,Cup FinalDate,WinningTeam,Score,LosingTeam,Location,Cup Final Attendance\n", "Q: Who was the winning team in the 1989 season?\n", "A: SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n", "END\n", "\n", "Comment: I'm\n", "\n", "expected answer\n", "SELECT WinningTeam FROM 1-12028543-3 WHERE Season = '1989'\n", "END\n", "\n", "from model\n", "table: 2-18096431-5\n", "columns: Place,Player,Country,Score,To par\n", "Q: What is To par, when Country is \"United States\", and when Player is \"Mark Brooks\"?\n", "A: 18\n", "END\n", "\\end{code}\n", "\n", "I want to get the value of To par, when Country is \"United States\", and when Player is \"Mark Brooks\".\n", "\n", "I\n", "\n", "expected answer\n", "SELECT To par FROM 2-18096431-5 WHERE Country = 'united states' AND Player = 'mark brooks'\n", "END\n", "\n", "from model\n", "table: 2-10701914-2\n", "columns: Home team,Home team score,Away team,Away team score,Venue,Crowd,Date\n", "Q: What home team played at the western oval?\n", "A: SELECT Home team 
FROM 2-10701914-2 WHERE Venue = 'western oval'\n", "END\n", "\\end{code}\n", "\n", "Comment:\n", "\n", "expected answer\n", "SELECT Home team FROM 2-10701914-2 WHERE Venue = 'western oval'\n", "END\n", "\n", "from model\n", "table: 1-29598261-1\n", "columns: Name,Number,Position,Height,Weight,Year,Hometown,Last School/College\n", "Q: what is the school for chris mcnamara?\n", "A: SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'chris mcnamara'\n", "END\n", "\\end{code}\n", "\n", "\n", "expected answer\n", "SELECT Last School/College FROM 1-29598261-1 WHERE Name = 'Chris McNamara'\n", "END\n", "\n", "from model\n", "table: 1-27722408-11\n", "columns: Game,Date,Team,Score,High points,High rebounds,High assists,Location Attendance,Record\n", "Q: Who had the most assists and how many did they have on April 8?\n", "A: SELECT High assists FROM 1-27722408-11 WHERE Date = 'april 8' AND High assists\n", "\n", "expected answer\n", "SELECT High assists FROM 1-27722408-11 WHERE Date = 'April 8'\n", "END\n", "\n", "from model\n", "table: 1-21378339-5\n", "columns: Draw,Song,Artist,Panel Points,Televotes,Televote Points,Score,Placing\n", "Q: Name the number of artists for panel points being 5\n", "A: SELECT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n", "END\n", "\\end{code}\n", "\n", "\n", "\n", "expected answer\n", "SELECT COUNT Artist FROM 1-21378339-5 WHERE Panel Points = 5\n", "END\n", "\n", "from model\n", "table: 2-11545282-17\n", "columns: Player,Nationality,Position,Years for Jazz,School/Club Team\n", "Q: What position does Michael Ruffin play?\n", "A: 11-17655555-1\n", "END\n", "\\end{code}\n", "\n", "Comment: I'm not sure what you mean by \"I want to get the position of Michael Ruffin\n", "\n", "expected answer\n", "SELECT Position FROM 2-11545282-17 WHERE Player = 'michael ruffin'\n", "END\n", "\n", "from model\n", "table: 1-17801022-1\n", "columns: Year,Date,Driver,Manufacturer,Laps,Miles (km),Race Time,Average Speed (mph)\n", "Q: What manufacturer 
def get_query(q):
    """Greedy-decode the model's continuation of prompt string `q`.

    NOTE(review): max_length=100 counts the prompt tokens as well, so longer
    prompts leave very little room for the generated SQL -- presumably
    max_new_tokens was intended; confirm before changing.
    """
    input_ids = tokenizer(q, return_tensors='pt').input_ids.to('cuda')
    generated = model.generate(input_ids, max_length=100)
    return tokenizer.decode(generated[0])


M = len(q_red)  # NOTE(review): rebinds M (was len(q_s) in the data-prep cell)

# Spot-check 10 random held-in examples: model output vs. reference SQL.
for _ in range(10):
    idx = random.randint(0, M - 1)
    prompt = q_red[idx]
    expected = a_red[idx]

    answer = get_query(prompt)

    print('from model')
    print(answer)
    print()
    print('expected answer')
    print(expected)