{ "cells": [ { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a9428ee09f334655b6b261d478cbd3d0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import LlamaTokenizer, LlamaForCausalLM\n", "from peft import prepare_model_for_int8_training\n", "tokenizer = LlamaTokenizer.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\")\n", " \n", "tokenizer.pad_token_id = 0\n", "tokenizer.padding_side = 'left'\n", "\n", "model = LlamaForCausalLM.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\",\n", " load_in_8bit=True,\n", " device_map=\"auto\",\n", " torch_dtype=torch.float16\n", ")\n", "\n", "model = prepare_model_for_int8_training(model)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "table: 2-13081928-2\n", "columns: Country,Chart,Period,Peak position,Sales\n", "Q: Name the period for Chart of g-music j-pop/k-pop chart\n", "A: SELECT Period FROM 2-13081928-2 WHERE Chart = 'g-music j-pop/k-pop chart'\n", "\n", "table: 2-13612447-1\n", "columns: Fraction,Ellipsis,Vinculum,Dots,Parentheses\n", "Q: What is the dot value when the ellipsis is 0.012345679…?\n", "A: SELECT Dots FROM 2-13612447-1 WHERE Ellipsis = '0.012345679…'\n", "\n", "table: 1-168274-1\n", "columns: Company,ICB Sector,Ticker symbol,Index weighting (%) at 17 January 2013,Market cap. 
at April 2013 (€)\n", "Q: Name the total number of index weighting % at 17 january 2013 for bouygues\n", "A: SELECT COUNT Index weighting (%) at 17 January 2013 FROM 1-168274-1 WHERE Company = 'Bouygues'\n", "\n", "table: 2-15826191-2\n", "columns: Rank,Nation,Gold,Silver,Bronze,Total\n", "Q: What is the lowest gold when there are 0 bronze and the total is less than 2, and silver is less than 0?\n", "A: SELECT MIN Gold FROM 2-15826191-2 WHERE Bronze = 0 AND Total < 2 AND Silver < 0\n", "\n", "table: 2-16387912-1\n", "columns: Home team,Home team score,Away team,Away team score,Ground,Date,Time\n", "Q: What is Ground, when Away Team is Sydney?\n", "A: SELECT Ground FROM 2-16387912-1 WHERE Away team = 'sydney'\n" ] } ], "source": [ "import random\n", "import json\n", "\n", "# operator vocabularies defined by WikiSQL\n", "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", "cond_ops = ['=', '>', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d, cols, types, tid):\n", "    # render WikiSQL's structured query dict back into a SQL string\n", "    sel_index = d['sel']\n", "    agg_index = d['agg']\n", "    conditions = d['conds']\n", "    col = cols[sel_index]\n", "    rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", "        agg=agg_ops[agg_index],\n", "        sel=col,\n", "        tid=tid\n", "    )\n", "    if conditions:\n", "        cs = []\n", "        for i, o, v in conditions:\n", "            nm = cols[i]\n", "            op = cond_ops[o]\n", "            # quote condition values only for text-typed columns\n", "            if types[i] in ['text']:\n", "                val = f\"'{v}'\"\n", "            else:\n", "                val = v\n", "            cs.append(f'{nm} {op} {val}')\n", "        rep += ' WHERE ' + ' AND '.join(cs)\n", "    return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n",
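"# Note (added): 'prefix' is defined above but never used -- the training strings\n", "# built below are just the schema block plus 'Q: ...' and 'A: ...'.\n", "# Hypothetical fix_repr example (not from the dataset):\n", "#   fix_repr({'sel': 0, 'agg': 3, 'conds': [[1, 0, 'x']]}, ['A','B'], ['text','text'], 't')\n", "#   -> \"SELECT COUNT A FROM t WHERE B = 'x'\"\n", "\n",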
"def tbl_def_to_string(id, header, types):\n", "    s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", "    return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", "    for line in f:\n", "        js = json.loads(line)\n", "        id = js['id']\n", "        hdr = js['header']\n", "        ts = js['types']\n", "        tbl_str[id] = tbl_def_to_string(id, hdr, ts)\n", "        tbl_cols[id] = hdr\n", "        tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", "    for line in f:\n", "        js = json.loads(line)\n", "        id = js['table_id']\n", "        s = tbl_str[id]\n", "        qst = js['question']\n", "        nl = s + '\\nQ: ' + qst + '\\nA: '\n", "        q_s.append(nl)\n", "\n", "        sql = js['sql']\n", "        a = fix_repr(sql, tbl_cols[id], tbl_types[id], id)\n", "        a = a + \"\\nEND\\n\"\n", "        a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", "    j = random.randint(0, M-1)\n", "    print()\n", "    print(data_txt[j])\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "89\n", "                  0\n", "count  56355.000000\n", "mean      98.219519\n", "std       21.740325\n", "min       60.000000\n", "25%       84.500000\n", "50%       94.000000\n", "75%      106.000000\n", "max      458.000000\n", "35608\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "# keep only examples whose full prompt+answer tokenizes to under 100 tokens\n", "z = zip(q_s, lens)\n", "q_red = [a for a, b in z if b < 100]\n", "z = zip(a_s, lens)\n", "a_red = [a for a, b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n" ] },
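{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch (added, not part of the original run): spot-check the length filter --\n", "# every surviving prompt+answer string should fit the 100-token budget used for\n", "# padding/truncation in the next cell.\n", "s = random.choice(data_red)\n", "print(len(tokenizer(s).input_ids), 'tokens')\n", "print(s)\n" ] },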
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import random, datasets\n", "d = {'prompt': random.sample(data_red, 1000)}\n", "\n", "tokenizer.pad_token_id = tokenizer.eos_token\n", "\n", "data = datasets.Dataset.from_dict(d)\n", "data = data.map(lambda x:\n", " tokenizer(\n", " x['prompt'],\n", " truncation=True,\n", " max_length=100,\n", " padding=\"max_length\"\n", " ))\n", "\n", "data = data.remove_columns('prompt')\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from peft import LoraConfig, get_peft_model\n", "import transformers\n", "import datasets\n", "\n", "LORA_R = 4\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = .1\n", "CUTOFF_LEN = 256\n", "BATCH = 128\n", "MICRO_BATCH = 4\n", "N_GAS = BATCH//MICRO_BATCH\n", "EPOCHS = 1\n", "LR = 1e-4\n", "\n", "lora_cfg = LoraConfig(\n", " r = LORA_R,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " task_type='CASUAL_LM',\n", " target_modules=['q_proj','v_proj']\n", ")\n", "\n", "model = get_peft_model(model,lora_cfg)\n", "\n", "targs = transformers.TrainingArguments(\n", " per_device_train_batch_size=MICRO_BATCH,\n", " gradient_accumulation_steps=N_GAS,\n", " warmup_steps=0,\n", " num_train_epochs=EPOCHS,\n", " learning_rate=LR,\n", " fp16=True,\n", " logging_steps=1,\n", " output_dir='sqllama-out2',\n", " save_total_limit=3,\n", " remove_unused_columns=False\n", ")\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: left;\">\n", "      <th>Step</th>\n", "      <th>Training Loss</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><td>1</td><td>2.710700</td></tr>\n", "    <tr><td>2</td><td>2.680400</td></tr>\n", "    <tr><td>3</td><td>2.684500</td></tr>\n", "    <tr><td>4</td><td>2.625600</td></tr>\n", "    <tr><td>5</td><td>2.609600</td></tr>\n", "    <tr><td>6</td><td>2.619100</td></tr>\n", "    <tr><td>7</td><td>2.603800</td></tr>\n", "  </tbody>\n", "</table>"
],
"text/plain": [
"