diff --git "a/wikisql_inf.ipynb" "b/wikisql_inf.ipynb" new file mode 100644--- /dev/null +++ "b/wikisql_inf.ipynb" @@ -0,0 +1,1561 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "torch.cuda.is_available()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", + " warn(msg)\n", + "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", + "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", + "The class this function is called from is 'LlamaTokenizer'.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "===================================BUG REPORT===================================\n", + "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", + "================================================================================\n", + "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", + "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", + "CUDA SETUP: Detected CUDA version 113\n", + "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n", + "PeftConfig(peft_type='LORA', base_model_name_or_path='decapoda-research/llama-7b-hf', task_type='CASUAL_LM', inference_mode=True)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "855476d6c7594e6891de62b2848d7858", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/33 [00:00', '<', 'OP']\n", + "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", + "\n", + "def fix_repr(d,cols,types,tid):\n", + " sel_index=d['sel'] \n", + " agg_index=d['agg']\n", + " conditions=d['conds']\n", + " col = cols[sel_index]\n", + " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", + " agg=agg_ops[agg_index],\n", + " sel=col,\n", + " tid=tid\n", + " )\n", + " if conditions:\n", + " cs = []\n", + " for i, o, v in conditions:\n", + " #print(i,cols)\n", + " nm = cols[i]\n", + " op = cond_ops[o]\n", + " \n", + " if types[i] in ['text']:\n", + " val = f\"\\'{v}\\'\"\n", + " else:\n", + " val = v\n", + " cs.append(f'{nm} {op} {val}')\n", + " #print(cs)\n", + "\n", + " rep += ' WHERE ' + ' AND '.join(cs)\n", + " \n", + " return rep\n", + "\n", + "tbl_cols = {}\n", + "tbl_types = {}\n", + "tbl_str = {}\n", + "\n", + "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. 
Write a SQL query that retrieves the data.'\n", + "\n", + "def tbl_def_to_string(id, header, types):\n", + " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", + " return s\n", + "\n", + "with open('data/test.tables.jsonl') as f:\n", + " for line in f:\n", + " js = json.loads(line)\n", + " id = js['id']\n", + " hdr = js['header']\n", + " ts = js['types']\n", + " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", + " tbl_cols[id] = hdr\n", + " tbl_types[id] = ts\n", + "\n", + "q_s = []\n", + "a_s = []\n", + "\n", + "with open('data/test.jsonl') as f:\n", + " for line in f:\n", + " js = json.loads(line)\n", + " id = js['table_id']\n", + " s = tbl_str[id]\n", + " qst = js['question']\n", + " nl = s + '\\nQ: ' + qst + '\\nA: '\n", + " q_s.append(nl)\n", + "\n", + " sql = js['sql']\n", + " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", + " a = a + \"\\nEND\\n\"\n", + " a_s.append(a)\n", + "\n", + "M = len(q_s)\n", + "\n", + "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", + "\n", + "for i in range(5):\n", + " j = random.randint(0,M-1)\n", + " print()\n", + " print(data_txt[j]) \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "toks = [tokenizer(s) for s in data_txt]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "86\n", + " 0\n", + "count 15878.000000\n", + "mean 101.114750\n", + "std 21.311681\n", + "min 62.000000\n", + "25% 87.000000\n", + "50% 97.000000\n", + "75% 109.000000\n", + "max 390.000000\n", + "8981\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "print(len(toks[0].input_ids))\n", + "lens = np.array([len(tok.input_ids) for tok in toks])\n", + "print(pd.DataFrame(lens).describe())\n", + "\n", + "z = zip(q_s,lens)\n", + "q_red = [a for a,b in z if b < 100]\n", + "z = zip(a_s,lens)\n", + "a_red = [a for a,b in z if b < 100]\n", + "\n", + "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", + "print(len(data_red))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1412: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cuda, whereas the model is on cpu. You may experience unexpected behaviors or slower generation. 
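+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Quick sanity check of `fix_repr` on a hand-built record (hypothetical table id, columns and types, not taken from the dataset), assuming the standard WikiSQL lists `agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']` and `cond_ops = ['=', '>', '<', 'OP']` defined above. A sketch only; not executed in this run.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical WikiSQL-style record: count players on team 'Leeds' with more than 10 goals.\n",
+    "demo_sql = {'sel': 0, 'agg': 3, 'conds': [[1, 0, 'Leeds'], [2, 1, 10]]}\n",
+    "demo_cols = ['Player', 'Team', 'Goals']\n",
+    "demo_types = ['text', 'text', 'real']\n",
+    "print(fix_repr(demo_sql, demo_cols, demo_types, 'demo-1'))\n",
+    "# expected: SELECT COUNT Player FROM demo-1 WHERE Team = 'Leeds' AND Goals > 10\n"
+   ]
+  },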
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/generation/utils.py:1412: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cuda, whereas the model is on cpu. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cpu') before running `.generate()`.\n",
+      "  UserWarning,\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "Expected a cuda device, but got: cpu",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "ValueError                                Traceback (most recent call last)",
+      "/var/tmp/ipykernel_2860/1693296982.py in <module>\n     13     a = a_red[j]\n---> 15     ma = get_query(qs)",
+      "/var/tmp/ipykernel_2860/1693296982.py in get_query(q)\n      4     ctoks = toks.input_ids.to('cuda')\n----> 5     gen = model.generate(ctoks, max_length=100)",
+      "  [... repeated transformers / accelerate / peft / bitsandbytes frames omitted ...]",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/autograd/_functions.py in forward(ctx, A, B, out, bias, state)\n--> 273         using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/cuda/__init__.py in get_device_capability(device)\n--> 357     prop = get_device_properties(device)",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/cuda/__init__.py in get_device_properties(device)\n--> 372     device = _get_device_index(device, optional=True)",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/torch/cuda/_utils.py in _get_device_index(device, optional, allow_cpu)\n---> 30         raise ValueError('Expected a cuda device, but got: {}'.format(device))",
+      "ValueError: Expected a cuda device, but got: cpu"
+     ]
+    }
+   ],
+   "source": [
+    "def get_query(q):\n",
+    "    \n",
+    "    toks = tokenizer(q , return_tensors='pt')\n",
+    "    ctoks = toks.input_ids.to('cuda')\n",
+    "    gen = model.generate(ctoks, max_length=100)\n",
+    "    return tokenizer.decode(gen[0])\n",
+    "\n",
+    "M = len(q_red)\n",
+    "\n",
+    "for _ in range(10):\n",
+    "    j = random.randint(0,M-1)\n",
+    "    qs = q_red[j]\n",
+    "    a = a_red[j]\n",
+    "\n",
+    "    ma = get_query(qs)\n",
+    "\n",
+    "    #print(qs)\n",
+    "    print('from model')\n",
+    "    print(ma)\n",
+    "    print()\n",
+    "    print('expected answer')\n",
+    "    print(a)\n"
+   ]
+  },
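+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `ValueError` above is a device mismatch: `accelerate` has left part of the model on CPU while the inputs were moved to `cuda`. A minimal sketch of one possible fix, assuming the model is meant to run on the GPU: send the inputs to `model.device` instead of hard-coding `'cuda'` (and make sure the model itself was loaded onto the GPU). Not executed here.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_query_fixed(q):\n",
+    "    # model.device reports where the model's (first) parameters actually live,\n",
+    "    # so the input ids end up on the same device as the weights\n",
+    "    toks = tokenizer(q, return_tensors='pt')\n",
+    "    ctoks = toks.input_ids.to(model.device)\n",
+    "    gen = model.generate(ctoks, max_length=100)\n",
+    "    return tokenizer.decode(gen[0])\n"
+   ]
+  },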
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "encoding tokens\n",
+      "generating ids\n"
+     ]
+    },
+    {
+     "ename": "OutOfMemoryError",
+     "evalue": "CUDA out of memory. Tried to allocate 10.69 GiB (GPU 0; 14.56 GiB total capacity; 12.15 GiB already allocated; 1.35 GiB free; 12.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "OutOfMemoryError                          Traceback (most recent call last)",
+      "/var/tmp/ipykernel_721/1193102512.py in <module>\n      4 print('generating ids')\n----> 5 gen_ids = model.generate(**toks, max_length=100)",
+      "  [... repeated transformers / accelerate frames omitted ...]",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)\n--> 289         hidden_states = self.input_layernorm(hidden_states)",
+      "~/hf/sqllama-V0/.venv/lib/python3.7/site-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states)\n---> 84         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)",
+      "OutOfMemoryError: CUDA out of memory. Tried to allocate 10.69 GiB (GPU 0; 14.56 GiB total capacity; 12.15 GiB already allocated; 1.35 GiB free; 12.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
+     ]
+    }
+   ],
+   "source": [
+    "print('encoding tokens')\n",
+    "toks = tokenizer(q_red, return_tensors='pt', padding=True).to('cuda')\n",
+    " \n",
+    "print('generating ids')\n",
+    "gen_ids = model.generate(**toks, max_length=100)\n",
+    "print('decoding ids')\n",
+    "res = tokenizer.batch_decode(gen_ids)\n"
+   ]
+  },
Tried to allocate 10.69 GiB (GPU 0; 14.56 GiB total capacity; 12.15 GiB already allocated; 1.35 GiB free; 12.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
+     ]
+    }
+   ],
+   "source": [
+    "# this attempt tokenizes and generates over all 8,981 reduced prompts in one batch,\n",
+    "# which exceeds the 14.56 GiB of GPU memory and raises the OutOfMemoryError above\n",
+    "print('encoding tokens')\n",
+    "toks = tokenizer(q_red, return_tensors='pt', padding=True).to('cuda')\n",
+    "\n",
+    "print('generating ids')\n",
+    "gen_ids = model.generate(**toks, max_length=100)\n",
+    "print('decoding ids')\n",
+    "res = tokenizer.batch_decode(gen_ids)\n"
+   ]
+  },
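+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "One way around the OOM above, sketched but not run here: generate in small chunks instead of one giant batch. `batch_size` is a hypothetical knob to tune to the available GPU memory; the sketch reuses `tokenizer`, `model`, and `q_red` from earlier cells and assumes the tokenizer has a pad token configured."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# sketch only: chunked generation to stay within GPU memory\n",
+    "batch_size = 16  # hypothetical; raise or lower to fit the GPU\n",
+    "res = []\n",
+    "with torch.no_grad():\n",
+    "    for i in range(0, len(q_red), batch_size):\n",
+    "        chunk = q_red[i:i + batch_size]\n",
+    "        toks = tokenizer(chunk, return_tensors='pt', padding=True).to('cuda')\n",
+    "        gen_ids = model.generate(**toks, max_length=100)\n",
+    "        res.extend(tokenizer.batch_decode(gen_ids))\n",
+    "        del toks, gen_ids\n",
+    "        torch.cuda.empty_cache()  # free cached blocks between chunks\n"
+   ]
+  },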
"processing question 690\n", + "processing question 700\n", + "processing question 710\n", + "processing question 720\n", + "processing question 730\n", + "processing question 740\n", + "processing question 750\n", + "processing question 760\n", + "processing question 770\n", + "processing question 780\n", + "processing question 790\n", + "processing question 800\n", + "processing question 810\n", + "processing question 820\n", + "processing question 830\n", + "processing question 840\n", + "processing question 850\n", + "processing question 860\n", + "processing question 870\n", + "processing question 880\n", + "processing question 890\n", + "processing question 900\n", + "processing question 910\n", + "processing question 920\n", + "processing question 930\n", + "processing question 940\n", + "processing question 950\n", + "processing question 960\n", + "processing question 970\n", + "processing question 980\n", + "processing question 990\n", + "processing question 1000\n", + "processing question 1010\n", + "processing question 1020\n", + "processing question 1030\n", + "processing question 1040\n", + "processing question 1050\n", + "processing question 1060\n", + "processing question 1070\n", + "processing question 1080\n", + "processing question 1090\n", + "processing question 1100\n", + "processing question 1110\n", + "processing question 1120\n", + "processing question 1130\n", + "processing question 1140\n", + "processing question 1150\n", + "processing question 1160\n", + "processing question 1170\n", + "processing question 1180\n", + "processing question 1190\n", + "processing question 1200\n", + "processing question 1210\n", + "processing question 1220\n", + "processing question 1230\n", + "processing question 1240\n", + "processing question 1250\n", + "processing question 1260\n", + "processing question 1270\n", + "processing question 1280\n", + "processing question 1290\n", + "processing question 1300\n", + "processing question 1310\n", + "processing question 1320\n", + "processing question 1330\n", + "processing question 1340\n", + "processing question 1350\n", + "processing question 1360\n", + "processing question 1370\n", + "processing question 1380\n", + "processing question 1390\n", + "processing question 1400\n", + "processing question 1410\n", + "processing question 1420\n", + "processing question 1430\n", + "processing question 1440\n", + "processing question 1450\n", + "processing question 1460\n", + "processing question 1470\n", + "processing question 1480\n", + "processing question 1490\n", + "processing question 1500\n", + "processing question 1510\n", + "processing question 1520\n", + "processing question 1530\n", + "processing question 1540\n", + "processing question 1550\n", + "processing question 1560\n", + "processing question 1570\n", + "processing question 1580\n", + "processing question 1590\n", + "processing question 1600\n", + "processing question 1610\n", + "processing question 1620\n", + "processing question 1630\n", + "processing question 1640\n", + "processing question 1650\n", + "processing question 1660\n", + "processing question 1670\n", + "processing question 1680\n", + "processing question 1690\n", + "processing question 1700\n", + "processing question 1710\n", + "processing question 1720\n", + "processing question 1730\n", + "processing question 1740\n", + "processing question 1750\n", + "processing question 1760\n", + "processing question 1770\n", + "processing question 1780\n", + "processing question 1790\n", + "processing question 1800\n", + 
"processing question 1810\n", + "processing question 1820\n", + "processing question 1830\n", + "processing question 1840\n", + "processing question 1850\n", + "processing question 1860\n", + "processing question 1870\n", + "processing question 1880\n", + "processing question 1890\n", + "processing question 1900\n", + "processing question 1910\n", + "processing question 1920\n", + "processing question 1930\n", + "processing question 1940\n", + "processing question 1950\n", + "processing question 1960\n", + "processing question 1970\n", + "processing question 1980\n", + "processing question 1990\n", + "processing question 2000\n", + "processing question 2010\n", + "processing question 2020\n", + "processing question 2030\n", + "processing question 2040\n", + "processing question 2050\n", + "processing question 2060\n", + "processing question 2070\n", + "processing question 2080\n", + "processing question 2090\n", + "processing question 2100\n", + "processing question 2110\n", + "processing question 2120\n", + "processing question 2130\n", + "processing question 2140\n", + "processing question 2150\n", + "processing question 2160\n", + "processing question 2170\n", + "processing question 2180\n", + "processing question 2190\n", + "processing question 2200\n", + "processing question 2210\n", + "processing question 2220\n", + "processing question 2230\n", + "processing question 2240\n", + "processing question 2250\n", + "processing question 2260\n", + "processing question 2270\n", + "processing question 2280\n", + "processing question 2290\n", + "processing question 2300\n", + "processing question 2310\n", + "processing question 2320\n", + "processing question 2330\n", + "processing question 2340\n", + "processing question 2350\n", + "processing question 2360\n", + "processing question 2370\n", + "processing question 2380\n", + "processing question 2390\n", + "processing question 2400\n", + "processing question 2410\n", + "processing question 2420\n", + "processing question 2430\n", + "processing question 2440\n", + "processing question 2450\n", + "processing question 2460\n", + "processing question 2470\n", + "processing question 2480\n", + "processing question 2490\n", + "processing question 2500\n", + "processing question 2510\n", + "processing question 2520\n", + "processing question 2530\n", + "processing question 2540\n", + "processing question 2550\n", + "processing question 2560\n", + "processing question 2570\n", + "processing question 2580\n", + "processing question 2590\n", + "processing question 2600\n", + "processing question 2610\n", + "processing question 2620\n", + "processing question 2630\n", + "processing question 2640\n", + "processing question 2650\n", + "processing question 2660\n", + "processing question 2670\n", + "processing question 2680\n", + "processing question 2690\n", + "processing question 2700\n", + "processing question 2710\n", + "processing question 2720\n", + "processing question 2730\n", + "processing question 2740\n", + "processing question 2750\n", + "processing question 2760\n", + "processing question 2770\n", + "processing question 2780\n", + "processing question 2790\n", + "processing question 2800\n", + "processing question 2810\n", + "processing question 2820\n", + "processing question 2830\n", + "processing question 2840\n", + "processing question 2850\n", + "processing question 2860\n", + "processing question 2870\n", + "processing question 2880\n", + "processing question 2890\n", + "processing question 2900\n", + "processing question 2910\n", + 
"processing question 2920\n", + "processing question 2930\n", + "processing question 2940\n", + "processing question 2950\n", + "processing question 2960\n", + "processing question 2970\n", + "processing question 2980\n", + "processing question 2990\n", + "processing question 3000\n", + "processing question 3010\n", + "processing question 3020\n", + "processing question 3030\n", + "processing question 3040\n", + "processing question 3050\n", + "processing question 3060\n", + "processing question 3070\n", + "processing question 3080\n", + "processing question 3090\n", + "processing question 3100\n", + "processing question 3110\n", + "processing question 3120\n", + "processing question 3130\n", + "processing question 3140\n", + "processing question 3150\n", + "processing question 3160\n", + "processing question 3170\n", + "processing question 3180\n", + "processing question 3190\n", + "processing question 3200\n", + "processing question 3210\n", + "processing question 3220\n", + "processing question 3230\n", + "processing question 3240\n", + "processing question 3250\n", + "processing question 3260\n", + "processing question 3270\n", + "processing question 3280\n", + "processing question 3290\n", + "processing question 3300\n", + "processing question 3310\n", + "processing question 3320\n", + "processing question 3330\n", + "processing question 3340\n", + "processing question 3350\n", + "processing question 3360\n", + "processing question 3370\n", + "processing question 3380\n", + "processing question 3390\n", + "processing question 3400\n", + "processing question 3410\n", + "processing question 3420\n", + "processing question 3430\n", + "processing question 3440\n", + "processing question 3450\n", + "processing question 3460\n", + "processing question 3470\n", + "processing question 3480\n", + "processing question 3490\n", + "processing question 3500\n", + "processing question 3510\n", + "processing question 3520\n", + "processing question 3530\n", + "processing question 3540\n", + "processing question 3550\n", + "processing question 3560\n", + "processing question 3570\n", + "processing question 3580\n", + "processing question 3590\n", + "processing question 3600\n", + "processing question 3610\n", + "processing question 3620\n", + "processing question 3630\n", + "processing question 3640\n", + "processing question 3650\n", + "processing question 3660\n", + "processing question 3670\n", + "processing question 3680\n", + "processing question 3690\n", + "processing question 3700\n", + "processing question 3710\n", + "processing question 3720\n", + "processing question 3730\n", + "processing question 3740\n", + "processing question 3750\n", + "processing question 3760\n", + "processing question 3770\n", + "processing question 3780\n", + "processing question 3790\n", + "processing question 3800\n", + "processing question 3810\n", + "processing question 3820\n", + "processing question 3830\n", + "processing question 3840\n", + "processing question 3850\n", + "processing question 3860\n", + "processing question 3870\n", + "processing question 3880\n", + "processing question 3890\n", + "processing question 3900\n", + "processing question 3910\n", + "processing question 3920\n", + "processing question 3930\n", + "processing question 3940\n", + "processing question 3950\n", + "processing question 3960\n", + "processing question 3970\n", + "processing question 3980\n", + "processing question 3990\n", + "processing question 4000\n", + "processing question 4010\n", + "processing question 4020\n", + 
"processing question 4030\n", + "processing question 4040\n", + "processing question 4050\n", + "processing question 4060\n", + "processing question 4070\n", + "processing question 4080\n", + "processing question 4090\n", + "processing question 4100\n", + "processing question 4110\n", + "processing question 4120\n", + "processing question 4130\n", + "processing question 4140\n", + "processing question 4150\n", + "processing question 4160\n", + "processing question 4170\n", + "processing question 4180\n", + "processing question 4190\n", + "processing question 4200\n", + "processing question 4210\n", + "processing question 4220\n", + "processing question 4230\n", + "processing question 4240\n", + "processing question 4250\n", + "processing question 4260\n", + "processing question 4270\n", + "processing question 4280\n", + "processing question 4290\n", + "processing question 4300\n", + "processing question 4310\n", + "processing question 4320\n", + "processing question 4330\n", + "processing question 4340\n", + "processing question 4350\n", + "processing question 4360\n", + "processing question 4370\n", + "processing question 4380\n", + "processing question 4390\n", + "processing question 4400\n", + "processing question 4410\n", + "processing question 4420\n", + "processing question 4430\n", + "processing question 4440\n", + "processing question 4450\n", + "processing question 4460\n", + "processing question 4470\n", + "processing question 4480\n", + "processing question 4490\n", + "processing question 4500\n", + "processing question 4510\n", + "processing question 4520\n", + "processing question 4530\n", + "processing question 4540\n", + "processing question 4550\n", + "processing question 4560\n", + "processing question 4570\n", + "processing question 4580\n", + "processing question 4590\n", + "processing question 4600\n", + "processing question 4610\n", + "processing question 4620\n", + "processing question 4630\n", + "processing question 4640\n", + "processing question 4650\n", + "processing question 4660\n", + "processing question 4670\n", + "processing question 4680\n", + "processing question 4690\n", + "processing question 4700\n", + "processing question 4710\n", + "processing question 4720\n", + "processing question 4730\n", + "processing question 4740\n", + "processing question 4750\n", + "processing question 4760\n", + "processing question 4770\n", + "processing question 4780\n", + "processing question 4790\n", + "processing question 4800\n", + "processing question 4810\n", + "processing question 4820\n", + "processing question 4830\n", + "processing question 4840\n", + "processing question 4850\n", + "processing question 4860\n", + "processing question 4870\n", + "processing question 4880\n", + "processing question 4890\n", + "processing question 4900\n", + "processing question 4910\n", + "processing question 4920\n", + "processing question 4930\n", + "processing question 4940\n", + "processing question 4950\n", + "processing question 4960\n", + "processing question 4970\n", + "processing question 4980\n", + "processing question 4990\n", + "processing question 5000\n", + "processing question 5010\n", + "processing question 5020\n", + "processing question 5030\n", + "processing question 5040\n", + "processing question 5050\n", + "processing question 5060\n", + "processing question 5070\n", + "processing question 5080\n", + "processing question 5090\n", + "processing question 5100\n", + "processing question 5110\n", + "processing question 5120\n", + "processing question 5130\n", + 
"processing question 5140\n", + "processing question 5150\n", + "processing question 5160\n", + "processing question 5170\n", + "processing question 5180\n", + "processing question 5190\n", + "processing question 5200\n", + "processing question 5210\n", + "processing question 5220\n", + "processing question 5230\n", + "processing question 5240\n", + "processing question 5250\n", + "processing question 5260\n", + "processing question 5270\n", + "processing question 5280\n", + "processing question 5290\n", + "processing question 5300\n", + "processing question 5310\n", + "processing question 5320\n", + "processing question 5330\n", + "processing question 5340\n", + "processing question 5350\n", + "processing question 5360\n", + "processing question 5370\n", + "processing question 5380\n", + "processing question 5390\n", + "processing question 5400\n", + "processing question 5410\n", + "processing question 5420\n", + "processing question 5430\n", + "processing question 5440\n", + "processing question 5450\n", + "processing question 5460\n", + "processing question 5470\n", + "processing question 5480\n", + "processing question 5490\n", + "processing question 5500\n", + "processing question 5510\n", + "processing question 5520\n", + "processing question 5530\n", + "processing question 5540\n", + "processing question 5550\n", + "processing question 5560\n", + "processing question 5570\n", + "processing question 5580\n", + "processing question 5590\n", + "processing question 5600\n", + "processing question 5610\n", + "processing question 5620\n", + "processing question 5630\n", + "processing question 5640\n", + "processing question 5650\n", + "processing question 5660\n", + "processing question 5670\n", + "processing question 5680\n", + "processing question 5690\n", + "processing question 5700\n", + "processing question 5710\n", + "processing question 5720\n", + "processing question 5730\n", + "processing question 5740\n", + "processing question 5750\n", + "processing question 5760\n", + "processing question 5770\n", + "processing question 5780\n", + "processing question 5790\n", + "processing question 5800\n", + "processing question 5810\n", + "processing question 5820\n", + "processing question 5830\n", + "processing question 5840\n", + "processing question 5850\n", + "processing question 5860\n", + "processing question 5870\n", + "processing question 5880\n", + "processing question 5890\n", + "processing question 5900\n", + "processing question 5910\n", + "processing question 5920\n", + "processing question 5930\n", + "processing question 5940\n", + "processing question 5950\n", + "processing question 5960\n", + "processing question 5970\n", + "processing question 5980\n", + "processing question 5990\n", + "processing question 6000\n", + "processing question 6010\n", + "processing question 6020\n", + "processing question 6030\n", + "processing question 6040\n", + "processing question 6050\n", + "processing question 6060\n", + "processing question 6070\n", + "processing question 6080\n", + "processing question 6090\n", + "processing question 6100\n", + "processing question 6110\n", + "processing question 6120\n", + "processing question 6130\n", + "processing question 6140\n", + "processing question 6150\n", + "processing question 6160\n", + "processing question 6170\n", + "processing question 6180\n", + "processing question 6190\n", + "processing question 6200\n", + "processing question 6210\n", + "processing question 6220\n", + "processing question 6230\n", + "processing question 6240\n", + 
"processing question 6250\n", + "processing question 6260\n", + "processing question 6270\n", + "processing question 6280\n", + "processing question 6290\n", + "processing question 6300\n", + "processing question 6310\n", + "processing question 6320\n", + "processing question 6330\n", + "processing question 6340\n", + "processing question 6350\n", + "processing question 6360\n", + "processing question 6370\n", + "processing question 6380\n", + "processing question 6390\n", + "processing question 6400\n", + "processing question 6410\n", + "processing question 6420\n", + "processing question 6430\n", + "processing question 6440\n", + "processing question 6450\n", + "processing question 6460\n", + "processing question 6470\n", + "processing question 6480\n", + "processing question 6490\n", + "processing question 6500\n", + "processing question 6510\n", + "processing question 6520\n", + "processing question 6530\n", + "processing question 6540\n", + "processing question 6550\n", + "processing question 6560\n", + "processing question 6570\n", + "processing question 6580\n", + "processing question 6590\n", + "processing question 6600\n", + "processing question 6610\n", + "processing question 6620\n", + "processing question 6630\n", + "processing question 6640\n", + "processing question 6650\n", + "processing question 6660\n", + "processing question 6670\n", + "processing question 6680\n", + "processing question 6690\n", + "processing question 6700\n", + "processing question 6710\n", + "processing question 6720\n", + "processing question 6730\n", + "processing question 6740\n", + "processing question 6750\n", + "processing question 6760\n", + "processing question 6770\n", + "processing question 6780\n", + "processing question 6790\n", + "processing question 6800\n", + "processing question 6810\n", + "processing question 6820\n", + "processing question 6830\n", + "processing question 6840\n", + "processing question 6850\n", + "processing question 6860\n", + "processing question 6870\n", + "processing question 6880\n", + "processing question 6890\n", + "processing question 6900\n", + "processing question 6910\n", + "processing question 6920\n", + "processing question 6930\n", + "processing question 6940\n", + "processing question 6950\n", + "processing question 6960\n", + "processing question 6970\n", + "processing question 6980\n", + "processing question 6990\n", + "processing question 7000\n", + "processing question 7010\n", + "processing question 7020\n", + "processing question 7030\n", + "processing question 7040\n", + "processing question 7050\n", + "processing question 7060\n", + "processing question 7070\n", + "processing question 7080\n", + "processing question 7090\n", + "processing question 7100\n", + "processing question 7110\n", + "processing question 7120\n", + "processing question 7130\n", + "processing question 7140\n", + "processing question 7150\n", + "processing question 7160\n", + "processing question 7170\n", + "processing question 7180\n", + "processing question 7190\n", + "processing question 7200\n", + "processing question 7210\n", + "processing question 7220\n", + "processing question 7230\n", + "processing question 7240\n", + "processing question 7250\n", + "processing question 7260\n", + "processing question 7270\n", + "processing question 7280\n", + "processing question 7290\n", + "processing question 7300\n", + "processing question 7310\n", + "processing question 7320\n", + "processing question 7330\n", + "processing question 7340\n", + "processing question 7350\n", + 
"processing question 7360\n", + "processing question 7370\n", + "processing question 7380\n", + "processing question 7390\n", + "processing question 7400\n", + "processing question 7410\n", + "processing question 7420\n", + "processing question 7430\n", + "processing question 7440\n", + "processing question 7450\n", + "processing question 7460\n", + "processing question 7470\n", + "processing question 7480\n", + "processing question 7490\n", + "processing question 7500\n", + "processing question 7510\n", + "processing question 7520\n", + "processing question 7530\n", + "processing question 7540\n", + "processing question 7550\n", + "processing question 7560\n", + "processing question 7570\n", + "processing question 7580\n", + "processing question 7590\n", + "processing question 7600\n", + "processing question 7610\n", + "processing question 7620\n", + "processing question 7630\n", + "processing question 7640\n", + "processing question 7650\n", + "processing question 7660\n", + "processing question 7670\n", + "processing question 7680\n", + "processing question 7690\n", + "processing question 7700\n", + "processing question 7710\n", + "processing question 7720\n", + "processing question 7730\n", + "processing question 7740\n", + "processing question 7750\n", + "processing question 7760\n", + "processing question 7770\n", + "processing question 7780\n", + "processing question 7790\n", + "processing question 7800\n", + "processing question 7810\n", + "processing question 7820\n", + "processing question 7830\n", + "processing question 7840\n", + "processing question 7850\n", + "processing question 7860\n", + "processing question 7870\n", + "processing question 7880\n", + "processing question 7890\n", + "processing question 7900\n", + "processing question 7910\n", + "processing question 7920\n", + "processing question 7930\n", + "processing question 7940\n", + "processing question 7950\n", + "processing question 7960\n", + "processing question 7970\n", + "processing question 7980\n", + "processing question 7990\n", + "processing question 8000\n", + "processing question 8010\n", + "processing question 8020\n", + "processing question 8030\n", + "processing question 8040\n", + "processing question 8050\n", + "processing question 8060\n", + "processing question 8070\n", + "processing question 8080\n", + "processing question 8090\n", + "processing question 8100\n", + "processing question 8110\n", + "processing question 8120\n", + "processing question 8130\n", + "processing question 8140\n", + "processing question 8150\n", + "processing question 8160\n", + "processing question 8170\n", + "processing question 8180\n", + "processing question 8190\n", + "processing question 8200\n", + "processing question 8210\n", + "processing question 8220\n", + "processing question 8230\n", + "processing question 8240\n", + "processing question 8250\n", + "processing question 8260\n", + "processing question 8270\n", + "processing question 8280\n", + "processing question 8290\n", + "processing question 8300\n", + "processing question 8310\n", + "processing question 8320\n", + "processing question 8330\n", + "processing question 8340\n", + "processing question 8350\n", + "processing question 8360\n", + "processing question 8370\n", + "processing question 8380\n", + "processing question 8390\n", + "processing question 8400\n", + "processing question 8410\n", + "processing question 8420\n", + "processing question 8430\n", + "processing question 8440\n", + "processing question 8450\n", + "processing question 8460\n", + 
"processing question 8470\n", + "processing question 8480\n", + "processing question 8490\n", + "processing question 8500\n", + "processing question 8510\n", + "processing question 8520\n", + "processing question 8530\n", + "processing question 8540\n", + "processing question 8550\n", + "processing question 8560\n", + "processing question 8570\n", + "processing question 8580\n", + "processing question 8590\n", + "processing question 8600\n", + "processing question 8610\n", + "processing question 8620\n", + "processing question 8630\n", + "processing question 8640\n", + "processing question 8650\n", + "processing question 8660\n", + "processing question 8670\n", + "processing question 8680\n", + "processing question 8690\n", + "processing question 8700\n", + "processing question 8710\n", + "processing question 8720\n", + "processing question 8730\n", + "processing question 8740\n", + "processing question 8750\n", + "processing question 8760\n", + "processing question 8770\n", + "processing question 8780\n", + "processing question 8790\n", + "processing question 8800\n", + "processing question 8810\n", + "processing question 8820\n", + "processing question 8830\n", + "processing question 8840\n", + "processing question 8850\n", + "processing question 8860\n", + "processing question 8870\n", + "processing question 8880\n", + "processing question 8890\n", + "processing question 8900\n", + "processing question 8910\n", + "processing question 8920\n", + "processing question 8930\n", + "processing question 8940\n", + "processing question 8950\n", + "processing question 8960\n", + "processing question 8970\n", + "processing question 8980\n" + ] + } + ], + "source": [ + "model.eval()\n", + "\n", + "def get_query(q):\n", + " with torch.no_grad():\n", + " toks = tokenizer(q , return_tensors='pt')\n", + " ctoks = toks.input_ids.to('cuda')\n", + " gen = model.generate(ctoks, max_length=100)\n", + " return tokenizer.decode(gen[0])\n", + "\n", + "\n", + "\n", + "model_data = []\n", + "\n", + "M = len(q_red)\n", + "\n", + "for i in range(M):\n", + " if i % 10 == 0:\n", + " print(f'processing question {i}')\n", + "\n", + " q = q_red[i]\n", + " a = a_red[i]\n", + " ma = get_query(q)\n", + " model_data.append((q,a,ma))\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'model_data' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/tmp/ipykernel_28855/677667941.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'wb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdump\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'model_data' is not defined" + ] + } + ], + "source": [ + "import pickle\n", + "\n", + "fn ='inf3.pkl'\n", + "\n", + "with open(fn,'wb') as 
f:\n", + " pickle.dump(model_data,f)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "case sensitive literals: True\n", + "test set size: 8981\n", + "correct select col: 86.84%\n", + "correct agg: 89.31%\n", + "correct tbl: 95.71%\n", + "correct conds: 67.38%\n", + "correct query: 57.14%\n" + ] + } + ], + "source": [ + "import re\n", + "fn = 'inf2.pkl'\n", + "\n", + "# defined by WikiSQL\n", + "\n", + "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", + "cond_ops = ['=', '>', '<', 'OP']\n", + "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", + "\n", + "\n", + "with open(fn,'rb') as f:\n", + " test_data = pickle.load(f)\n", + "\n", + "def parse_sel(toks):\n", + " t = 'N/A'\n", + " try:\n", + " t = toks[1].upper()\n", + " if t in agg_ops:\n", + " t = toks[2].upper()\n", + " except:\n", + " pass\n", + " return t\n", + "\n", + "def parse_table(toks):\n", + " s = ''\n", + " for i in range(len(toks)-1):\n", + " if toks[i].lower() == 'from':\n", + " s = toks[i+1]\n", + " return s \n", + "\n", + "def parse_agg(toks):\n", + " i = -1\n", + " try:\n", + " agg = toks[1].upper()\n", + " if agg in agg_ops:\n", + " i = agg_ops.index(agg)\n", + " except:\n", + " pass\n", + " return i\n", + "\n", + "def parse_conds(toks,val_to_lower=False):\n", + " \n", + " conds = set()\n", + " for i in range(len(toks)-3):\n", + " if toks[i].lower() in ['where','and']:\n", + " col = toks[i+1].lower()\n", + " op = toks[i+2]\n", + " op_id = -1\n", + " if op in cond_ops:\n", + " op_id = cond_ops.index(op)\n", + " val = toks[i+3]\n", + " if val_to_lower:\n", + " val = val.lower()\n", + " conds.add((col,op_id,val))\n", + " return conds\n", + " \n", + "cs = 0\n", + "ct = 0\n", + "ca = 0\n", + "cc = 0\n", + "c = 0\n", + "\n", + "case_sens = True\n", + "\n", + "M = len(test_data)\n", + "#print(M)\n", + "\n", + "for i in range(M):\n", + " s = test_data[i][2]\n", + " try:\n", + " m = re.search('A:((.|\\n)*?)END',s)\n", + " s = m.group(1).strip()\n", + " except:\n", + " pass\n", + " \n", + " sa = test_data[i][1][0:-5]\n", + " toks = s.split()\n", + " toksa = sa.split()\n", + " #print(s + \" | \" + sa)\n", + " sel = parse_sel(toks)\n", + " sela = parse_sel(toksa)\n", + " tbl = parse_table(toks)\n", + " tbla = parse_table(toksa)\n", + " agg = parse_agg(toks)\n", + " agga = parse_agg(toksa)\n", + " conds = parse_conds(toks,val_to_lower= not case_sens)\n", + " condsa = parse_conds(toksa,val_to_lower= not case_sens)\n", + " bs = sel==sela\n", + " bt = tbl==tbla\n", + " ba = agg==agga\n", + " bc = conds==condsa\n", + " if bs:\n", + " cs += 1\n", + " if bt:\n", + " ct += 1\n", + " if ba:\n", + " ca += 1\n", + " if bc:\n", + " cc += 1\n", + "\n", + " if bs and bt and ba and bc:\n", + " c += 1\n", + " \n", + " #print(sel + \" | \" + sela + \" | \" + str(sel==sela))\n", + " #print(tbl + \" | \" + tbla + \" | \" + str(tbl==tbla))\n", + " #print(str(agg) + \" | \" + str(agga) + \" | \" + str(agg==agga)) \n", + " #print(str(conds) + \" | \" + str(condsa) + \" | \" + str(conds==condsa))\n", + "\n", + "print(f'case sensitive literals: {case_sens}')\n", + "print(f'test set size: {M}')\n", + "print(f'correct select col: {cs/M*100:.2f}%')\n", + "print(f'correct agg: {ca/M*100:.2f}%')\n", + "print(f'correct tbl: {ct/M*100:.2f}%')\n", + "print(f'correct conds: {cc/M*100:.2f}%')\n", + "print(f'correct query: {c/M*100:.2f}%')\n", + "\n", + "\n", + 
"\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "6a381460736e8a0eabfb35eafae436ba15c06439de44e28b965ea473bd8dda90" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}