{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kH18jD5cR_Ks", "outputId": "12f11322-9001-418d-ecbd-18031bd34a5d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.4/122.4 MB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.9/310.9 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "# !pip install -q accelerate peft bitsandbytes transformers trl faiss-gpu langchain_community wandb flash-attn\n", "!pip install -q accelerate peft bitsandbytes transformers trl datasets\n", "\n", "# flash-attn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cgVNTbBa-D3j" }, "outputs": [], "source": [ "# load the required packages.\n", "import torch\n", "from datasets import load_dataset, Dataset\n", "from peft import LoraConfig, AutoPeftModelForCausalLM, PeftModel, get_peft_model\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, AutoConfig, set_seed\n", "from trl import SFTTrainer\n", "import bitsandbytes as bnb\n", "import transformers\n", "\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import sqlparse\n", "import re\n", "import json\n", "\n", "from huggingface_hub import hf_hub_download\n", "from huggingface_hub import HfFileSystem" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "basaX_55Yf_D" }, "outputs": [], "source": [ "#transformers.logging.set_verbosity_info()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bkkjgGdlrNcq" }, "outputs": [], "source": [ "WRITE_TOKEN = userdata.get('hf_write')\n", "READ_TOKEN = userdata.get('hf_read')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7CKnwlRfZj4V" }, "outputs": [], "source": [ "model_name = \"stabilityai/stable-code-instruct-3b\"\n", "out_name = \"lleticiasilvaa/StableCode-schemaLinking-v0-promptCerto\"\n", "prev_checkpoint = None #\"checkpoint-500\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zs7nCmt-pMC" }, "outputs": [], "source": [ 
"#!huggingface-cli login" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PZdnxs8k-Cgl" }, "outputs": [], "source": [ "spider_id=\"NESPED-GEN/spider_selector_schemaReduzido\"" ] }, { "cell_type": "markdown", "metadata": { "id": "xT2iRdCN_MFH" }, "source": [ "### Load Data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 368, "referenced_widgets": [ "761ff132a1784b8eb0c361a9c3af441d", "5caa7ef820fe42c082e787bc56d10226", "be01605b838048ae85c7167d89cd6516", "a6f9c5edbdbf4d3683928a338105d1c5", "d9ba0122445d41d897ca51246ce19c6f", "af26f382eddf47bcbd822a2e7a393365", "94931961de2b44b48ac40e9414b873f9", "caf381e973a44af181dbc4d42a503f33", "2aedcd2752f64539a5968db5e6a07c8b", "7341de7ff2a84a2c9607c081a5d5b460", "b18c863d1fd945dabc974224fcc9b4b8", "2fa6661d8e264a7c9b6fa0fab5f35a59", "78d6b204ac4948eb9d69a476034bd822", "cc122dcced2c4745b9d75f46bed205ab", "4dd80422743747fc9bc8ba100a9b20fc", "8aa9e08d00014f8283fba7940d1a9c89", "ba3fe8d44af54f808bce4f87a389f619", "5a7e58fdd2b149598e85697723e86a1f", "fd049ba7176c4a0ab248de7562df04d8", "c345e98a60a04d8c9731b21e65b6349c", "119127d225d342dbb9ed2342e58bbfe7", "2119994228b4440f96dfafa974bb1ec7", "1110a23020474da589154b7fdee9f789", "e9106a60cc3a4d2f837ef077786bbef6", "c3b6cfc25e494c87b62c47a8f664ccfd", "4578ca60a3f04d7db9ee6032362bbf0c", "cfd0bf07885a4746b024010509034762", "308bf946c5154224b11bfa43d2d64cbf", "52ed7f1cfb0c48f7b92e6dc74f5b9832", "bf75eb88d5fd483983a6aa18d1c999e6", "82ae814ed43840c8b7cfdfcf412cb5d6", "2bf63132a88c4a348841c9c37bf263da", "3ccdf5730b8f4ff881f6d074b750e838", "1735600c8f984c4584e6127e78059d8e", "ac9aa30dda934b40bdd10e070dd4f6f8", "8ed39247a8d04714b2404ad7ea153492", "f0dbc1b080814f5daff25f099b782be3", "510d4fb426b04d3fa953e071a7a2704c", "88bcfe8ec58b40f9b28b6a3aba4476ae", "0146951df0a44873bd3508b4740da410", "3152b4af351e4c00a0165df32c17e03d", "5cb9deae84dc4036b2549bc1325053c6", "d9f80c97a5f24095819dcaeebff6db41", "6d34d2c194f443f0a803ede212abdf1c", "2c1a3f0f7d674f79aaa013e62016627c", "7cdb9c6ee60f46cb9e7f9f9bb025c20d", "3f835d944f034eadbf3b01466f1ff006", "d7ce27b4a8dc488ba430007b8d9c6d2b", "76bfa38744ba40f2835fc4401053327c", "ac99a2192cfa460fbabc861262f9da36", "b10ac61e56df4250ac46f1d3ab35d637", "28d5f6e29cfd40d3b2f0d332dcdf2c93", "da9ed654dee2428aa9e5de3195c8f3ca", "e17d701ad2de485ba306e1584c56f965", "2283a2506c084205b30fee7b1b2eea9f" ] }, "id": "lLKgY40efdJo", "outputId": "de81874f-5c51-460a-a3f0-51b57e419433" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "761ff132a1784b8eb0c361a9c3af441d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/885 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fa6661d8e264a7c9b6fa0fab5f35a59", "version_major": 2, "version_minor": 0 }, 
"text/plain": [ "dev-00000-of-00001.parquet: 0%| | 0.00/369k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1110a23020474da589154b7fdee9f789", "version_major": 2, "version_minor": 0 }, "text/plain": [ "train-00000-of-00001.parquet: 0%| | 0.00/2.70M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1735600c8f984c4584e6127e78059d8e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating dev split: 0%| | 0/1034 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c1a3f0f7d674f79aaa013e62016627c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0%| | 0/8656 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Dataset({\n", " features: ['db_id', 'question_en', 'query', 'answer', 'hardness', 'query_llm', 'schema_SQLDatabase', 'schema_our', 'schema_dict', 'selector', 'selector_correct', 'schema_SQLDatabase_reduzido', 'schema_SQLDatabase_reduzido_tabelas'],\n", " num_rows: 8656\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spider = load_dataset(spider_id, split=\"train\")\n", "spider" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "po1gNdLjFKdZ" }, "outputs": [], "source": [ "df = spider.to_pandas()" ] }, { "cell_type": "markdown", "metadata": { "id": "ON-hPsEYM1Bu" }, "source": [ "# Load Base Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yEAZpfzlNOHW" }, "outputs": [], "source": [ "def download_checkpoint(adapter_model_id, checkpoint):\n", " fs = HfFileSystem()\n", " for file in fs.ls(f'{adapter_model_id}/{checkpoint}', detail=False):\n", " file_name = file.split(checkpoint)[-1]\n", "\n", " hf_hub_download(repo_id=adapter_model_id, filename=(f'{checkpoint}{file_name}'), local_dir='out')\n", "\n", " for file in fs.ls(f'{adapter_model_id}/logs', detail=False):\n", " file_name = file.split(checkpoint)[-1]\n", "\n", " hf_hub_download(repo_id=adapter_model_id, filename=(f'logs/{file_name.split(\"/\")[-1]}'), local_dir='out')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "enUxjGXEqHxg" }, "outputs": [], "source": [ "# download_checkpoint(out_name, prev_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 337, "referenced_widgets": [ "aa85d2a6ae0540969e6f7c28d8c108dd", "08b65668595341b19a9119cb9358bc94", "8cf951daf94c4b328885027bb5ec575f", "244c7492aaa946d49d44287294d5eef3", "941cfa475c344c5ab46f874ed1271e81", "d2a9aaf865f04c9f89aaf53c2f4817f4", "f4a73a08e2e5493d8765019ef7fe9ae1", "1b081e2e35514ac692e97c7ef49f5c15", "e3a6a04df76b434f8f7c334a87906787", "5278fafb7db44b718c95ebbf5839545d", "a572e1f30aca4cabb4be35844f76a15e", "31807fd45eb341b79d2879b0c313c062", "6ae8c6992a73442da401a29a8fa8c0ef", "5b35f56cd5d34406afdc3d574ec58387", "b6c5bb95124e4f32ab37317e85cad585", "75bfd69219244a2c99d58f0790c589fe", "5e4ecd7200a94b3eb77c63bf82bae1ec", "e9b7088de82f4c07b71e8f26a49ec728", "11dbca784ce440dfbc1d3ce3fa8a517d", "a070f23095ab46aab1eb4a98c6a4ba42", "c1521cb4455d47dcaf61f9eae2a593c0", "c8fa917147ee4d5481bee7b104e5682b", "128e5c86402f44b29debd5e3a137d8f0", "2ad7b6745733469cada466947fbe2b30", 
"253a2834a03f4f7ea917758f16eb95dc", "1e0a2bcc4bc140c69ef1a1bbd400264c", "fb3d4da99d404b4f9c13d99fec9f5b6f", "2b36bb42d97c480aa4618e98b4c9074b", "6e0964a247cf4dceb509262497bbd29d", "980ba74972494157957e0bb7912f1b94", "2ca4abd2471a46aab7e5ab77f392733f", "dce18a4596e341dda8c433eed90e36d6", "00ac9173494d4cbdbbe454a426866d82", "16db42ce144d4ffc955daed3b5bbe041", "8473a2552add4f7cb9b3979716744c8c", "516738ab026f4e6c94bc1fe7fd76ccad", "ed30649f8d4948259c517fcc12814137", "8ef49f463494455dba39b14edd3855b7", "153f3ef7108145da89a4866a5adadf69", "95662605b76548a5a2fa3787967e59c8", "1efff4efc9e044c08a5f7e5c27c90f73", "12085e5e5abe4a2f811157897fef2eab", "9e9911d2e4b54305b0fdde9247051907", "82fedc6913f5492eb6222bb6c0791a7f", "d06f4b2daa1542fdb9d154f174051f70", "3b3c8c5c6e62489c877bd24eb5be80d5", "bf85b7c933ea4290bed1dccdee9bdacc", "c8608a2ee3e244138fb2fc3c1b31d749", "73fe90983ec048ccb1643625236a3f8a", "9650e5cb692440f4877c6b316f415421", "3bd5019557624d3dba2607132abf8a3c", "40b7e7191a7b4bd1a9d465c0ac58b5c0", "04570ba166bc4d3a9a56eab772fdef91", "fdc7d792929b480e8276b3d73cc3591a", "55c89494025645b986fc744043c3b068", "cd64fd17458d42d6ad0645856d702cb5", "945a7818f30748fabecfd0544d03b719", "930731d7920646e2a1bedcbcfce1143b", "95d812217d7545b3a1e2019fc03d21b2", "eb7d62e0331e4c2ba691c4aa13cfdfc9", "6427121d1c9b46babc39691aa52dafd3", "5242e3b6ef4740779d58974b888dd8c8", "531e0f2b93e94e51a742c140ef8beae0", "c90e82018d594b46a56b2efe790f00c1", "5122b36e79d04afe9d3d003c4769fb1d", "16e58a2fc1e94688bbca8641bd646035", "02eadbcf3ed643229c8f3daa3914d027", "db5762a6f1354e8fadf6e77924b6cdd5", "fe553597f802443c8557b118a22db2dd", "b7924ce0a6c141509f0ec5b2f641c425", "ebdea15a36044414af0149c3affb51f8", "713657575e0642b28c9ce2eb3b9405c3", "e889619634b749ef87383b6fef67111c", "344b6b7ff19a4588beb10272b84875ef", "10e326a25d794f489f620b9320b915ea", "ee5e33c1b7844cf080cf916616c3a4b5", "3eecc02c90594619a86e37d016756da6", "09872654815e4e23a2ef51c0fd6f67cc", "86f160c486504079a9963e74a0c3a84e", "3e6b77ebec0f45878c932dd15918eea7", "2ab61c9e867748b5ae3b82d13923e330", "89943a38d7794efb8c5910db1c13ad88", "4f63a20edb554c7088421e83d132a25f", "ad3fdf54970f4dfab121a08f1647aed2", "0ff9916bac384664a120255a98ca1da7", "97b5c21d7cf94b4dba442b15457a8013", "020349c0de2f40578ba34d39cf9c47ee", "5389155d53ae43a4b7989a94f3d41de3", "6a290e922c254a7faed5b89435a9b339", "880c8793631f46f1b0a36ad912271687", "6c435523bdb24b24b17e17c86fe61176", "022f018ba1b44f68a8c16826d393c303", "6f8e79f15f304dfbabd0a42b7c8cbf66", "e038423a26a24aa8beb7100c0f18f9df", "ccfca5ddfadc4e8a9b6ee34cd78ae1ed", "56225c71e90345a192f5a7076e42c523", "9c7730b4018e466390befa6fb5858490", "7112ef7631ce45afbb58bd298057e802", "a267c8d28b454dee974e431dfb849f95", "fce0b64cf8094dd8ae8606c07076cf41", "5e4ae768b6294386af6f955edd55020a", "04ae5f38b9e4421ea6312a8465e73a4d", "545ee6e11d5344cabaf7438251a01aa2", "794be39f79b041caaeeffc7a50e271d2", "a9b49b16bacd4fc7b9a18a8d64a4d90f", "9dbef0b16729429595a1ebbdd4b919be", "7ba3484aa6384cfb8a48f3bb2e5498e2", "e2980e9183dd4ca0bfc0f5a56c343dbc", "edc9811bd0a4455ba76670423b8acf1c", "6bf3d52364c94dd2b8a53d0ce54886ce" ] }, "id": "M7DoqQMlM_nW", "outputId": "e9981ff1-696c-43f4-f4d8-27329cb63f08" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa85d2a6ae0540969e6f7c28d8c108dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/9.35k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"31807fd45eb341b79d2879b0c313c062", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/2.12M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "128e5c86402f44b29debd5e3a137d8f0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/587 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "16db42ce144d4ffc955daed3b5bbe041", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/738 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d06f4b2daa1542fdb9d154f174051f70", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors.index.json: 0%| | 0.00/29.4k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cd64fd17458d42d6ad0645856d702cb5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading shards: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "02eadbcf3ed643229c8f3daa3914d027", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00001-of-00002.safetensors: 0%| | 0.00/4.98G [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "09872654815e4e23a2ef51c0fd6f67cc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00002-of-00002.safetensors: 0%| | 0.00/610M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6a290e922c254a7faed5b89435a9b339", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fce0b64cf8094dd8ae8606c07076cf41", "version_major": 2, "version_minor": 0 }, "text/plain": [ "generation_config.json: 0%| | 0.00/132 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "seed=14\n", "\n", "\n", "if (prev_checkpoint != None):\n", " try:\n", " download_checkpoint(out_name, prev_checkpoint)\n", " except:\n", " pass\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=READ_TOKEN, map_device=\"auto\", add_eos_token=True, use_fast=True)\n", "\n", "new_tokens = {'additional_special_tokens': ['[SQL]','[/SQL]', '[QUESTION]','[/QUESTION]']}\n", "#adicionar tokens especiais:\n", "# if (prev_checkpoint == None):\n", "# tokenizer.add_special_tokens(new_tokens)\n", "\n", "\n", "if torch.cuda.is_bf16_supported():\n", " compute_dtype = torch.bfloat16\n", " attn_implementation = 'flash_attention_2'\n", "else:\n", " compute_dtype = torch.float16\n", " attn_implementation = 'sdpa'\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"\n", "\n", "\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=compute_dtype,\n", " bnb_4bit_use_double_quant=False,\n", ")\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " torch_dtype=compute_dtype,\n", 
" device_map=\"auto\",\n", " quantization_config=bnb_config,\n", "\n", " trust_remote_code=True,\n", " token=READ_TOKEN,\n", " # attn_implementation=attn_implementation\n", ")\n", "\n", "# se adicionar special_tokens tem que fazer resize do tokenizer:\n", "# model.resize_token_embeddings(len(tokenizer))\n", "\n", "## model.resize_token_embeddings(max(len(tokenizer), model.config.vocab_size))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fLuqzhSJBvi8", "outputId": "65293049-fcc8-491b-aaca-579bca0686c7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n", "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n" ] }, { "data": { "text/plain": [ "Embedding(32004, 2048)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# se adicionar special_tokens tem que fazer resize do tokenizer:\n", "#model.resize_token_embeddings(len(tokenizer))\n", "\n", "#model.resize_token_embeddings(max(len(tokenizer), model.config.vocab_size))" ] }, { "cell_type": "markdown", "metadata": { "id": "_I7-bFfm5gqS" }, "source": [ "#### Chat Template - Gerar SQL" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cYVA3Q7ZCzHi" }, "outputs": [], "source": [ "# tokenizer.chat_template = \"\"\"\n", "# {% if messages[0]['role'] == 'system' %}\n", "# {% set loop_messages = messages[1:] %}\n", "# {% set system_message = messages[0]['content'] %}\n", "# {% else %}\n", "# {% set loop_messages = messages %}\n", "# {% set system_message = 'Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.' 
%}\n", "# {% endif %}\n", "# {{ '# <|system|>/n/' + system_message + '/n//n/' }}\n", "# {% if messages|selectattr(\"role\", \"equalto\", \"example\")|list %}\n", "# Below are some examples of question and their corresponding SQL queries:/n//n/\n", "# {% else %}\n", "# /n/\n", "# {% endif %}\n", "# {% for message in loop_messages %}\n", "# {% if message['role'] == 'example' %}\n", "# {{ message['content'] }}/n//n/\n", "# {% elif message['role'] == 'schema' %}\n", "# # <|schema|>/n/The query will run on a database with the following schema:/n/{{ message['content'] }}/n//n/\n", "# {% elif message['role'] == 'user' %}\n", "# # <|user|>/n/[QUESTION]{{ message['content'] }}[/QUESTION]/n//n/\n", "# {% elif message['role'] == 'assistant' %}\n", "# # <|assistant|>/n/[SQL]{{ message['content'] }}[/SQL]\n", "# {% endif %}\n", "# {% endfor %}\n", "# {% if add_generation_prompt %}\n", "# # <|assistant|>/n/[SQL]\n", "# {% endif %}\n", "# \"\"\".replace(\"\\n\",\"\").replace(\" \", \"\").replace(\"/n/\", \"\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IE87Vlt1hrje" }, "outputs": [], "source": [ "import re\n", "\n", "def replace_alias_with_table(query):\n", " # Expressão regular para encontrar tabelas com alias, capturando o nome da tabela e o alias\n", " alias_pattern = re.compile(r'(\\bFROM\\b|\\bJOIN\\b)\\s+(\\w+)\\s+AS\\s+(\\w+)', re.IGNORECASE)\n", "\n", " # Substituições de aliases encontrados no padrão\n", " aliases = {match.group(3): match.group(2) for match in alias_pattern.finditer(query)}\n", "\n", " # Substituir cada alias pelo nome da tabela correspondente\n", " for alias, table in aliases.items():\n", " query = re.sub(r'\\b' + alias + r'\\b', table, query)\n", "\n", " # Remover 'AS' e alias das cláusulas 'FROM' e 'JOIN'\n", " query = re.sub(r'\\bAS\\s+\\w+', '', query, flags=re.IGNORECASE)\n", " return query" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6RDDdVgP5gqT" }, "outputs": [], "source": [ "def to_sql(query):\n", " return sqlparse.format(replace_alias_with_table(query), reindent=True, keyword_case='upper')\n", "\n", "def apply_template(row, tokenizer=tokenizer, n_examplea=0):\n", " question = row['question_en']\n", " schema = row['schema_SQLDatabase_reduzido']\n", " sql = to_sql(row['query'])\n", "\n", " system = \"Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.\"\n", "\n", " chat = [\n", " {'role': 'system', 'content': system},\n", " {'role': 'user', 'content': f\"# Schema:\\n```sql\\n{schema}\\n```\\n\\n# Question: {question}\"},\n", " {'role': 'assistant', 'content': f\"```sql\\n{sql}\\n```\\n\"}\n", " ]\n", "\n", " # chat = [\n", " # {\"role\": \"schema\", \"content\": schema},\n", " # {\"role\": \"user\", \"content\": question},\n", " # {\"role\": \"assistant\", \"content\": sql},\n", " # ]\n", "\n", " row['text'] = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)\n", "\n", " return row" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2trHEegL5gqU" }, "outputs": [], "source": [ "# spider_chain = json.load(open(\"/content/drive/Shareddrives/LLMs/Datasets/spider/spider_chain_of_thought.json\", \"r\"))\n", "# bird_chain = json.load(open(\"/content/drive/Shareddrives/LLMs/Datasets/bird/bird_chain_of_thought.json\", \"r\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N4jWrC7s5gqU" }, "outputs": [], "source": [ "# df['CoT'] = 
spider_chain + bird_chain" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bTF0pBsw5gqU" }, "outputs": [], "source": [ "df = df.apply(apply_template, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L4tjUv7o5gqV" }, "outputs": [], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DfJvLaGR5gqV" }, "outputs": [], "source": [ "# df['n_tokens'] = df['text'].apply(lambda x: len(tokenizer.encode(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vJseOHIu5gqW" }, "outputs": [], "source": [ "# import seaborn as sns\n", "# sns.histplot(df['n_tokens'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PIvSnr6Y5gqW", "outputId": "2d3c87b0-eafc-487a-d563-8dc708283d32" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.<|im_end|>\n", "<|im_start|>user\n", "# Schema:\n", "```sql\n", "CREATE TABLE city (\n", " City_ID INT,\n", " Status TEXT,\n", " Population REAL,\n", " PRIMARY KEY (City_ID)\n", ");\n", "```\n", "\n", "# Question: Show the status shared by cities with population bigger than 1500 and smaller than 500.<|im_end|>\n", "<|im_start|>assistant\n", "```sql\n", "SELECT Status\n", "FROM city\n", "WHERE Population > 1500 INTERSECT\n", " SELECT Status\n", " FROM city WHERE Population < 500\n", "```\n", "<|im_end|>\n", "\n" ] } ], "source": [ "print(df['text'][df.index[50]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "roZzKNOj5gqW" }, "outputs": [], "source": [ "_df = pd.DataFrame(columns=['text'])\n", "_df['text'] = df.sample(frac=1, random_state=14).reset_index(drop=True)['text']\n", "_df = Dataset.from_pandas(_df)\n", "_df = _df.train_test_split(test_size=0.01, shuffle=True, seed=14)\n", "train_dataset, valid_dataset = _df[\"train\"], _df[\"test\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "b6mjOblXeMup" }, "source": [ "#### Chat Template - Schema Linking" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "opbdu1g6eMuq" }, "outputs": [], "source": [ "def apply_template(row, tokenizer=tokenizer, n_examplea=0):\n", " question = row['question_en']\n", " schema = row['schema_SQLDatabase']\n", " schema_linking = row['selector_correct']\n", "\n", " system = \"Given a user question and the schema of a database, your task is to generate an JSON with the the names of tables and columns of the schema that the question is referring to.\"\n", "\n", " chat = [\n", " {'role': 'system', 'content': system},\n", " {'role': 'user', 'content': f\"# Schema:\\n```sql\\n{schema}\\n```\\n\\n# Question: {question}\"},\n", " {'role': 'assistant', 'content': f\"```json\\n{schema_linking}\\n```\"}\n", " ]\n", "\n", " row['text'] = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)\n", "\n", " return row" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_HHKJ9VeMur" }, "outputs": [], "source": [ "df = df.apply(apply_template, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oF0-4l8DeMus" }, "outputs": [], "source": [ "# df['n_tokens'] = df['text'].apply(lambda x: len(tokenizer.encode(x)))\n", "# import seaborn as sns\n", "# sns.histplot(df['n_tokens'])" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EFY_pTEteMut", "outputId": "02e8db81-66e6-4fc0-845d-8e57f09359ba" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "Given a user question and the schema of a database, your task is to generate an JSON with the the names of tables and columns of the schema that the question is referring to.<|im_end|>\n", "<|im_start|>user\n", "# Schema:\n", "```sql\n", "CREATE TABLE Addresses (\n", " address_id INTEGER,\n", " line_1 VARCHAR(80),\n", " line_2 VARCHAR(80),\n", " city VARCHAR(50),\n", " zip_postcode CHAR(20),\n", " state_province_county VARCHAR(50),\n", " country VARCHAR(50),\n", " PRIMARY KEY (address_id)\n", ");\n", "\n", "CREATE TABLE People (\n", " person_id INTEGER,\n", " first_name VARCHAR(255),\n", " middle_name VARCHAR(255),\n", " last_name VARCHAR(255),\n", " cell_mobile_number VARCHAR(40),\n", " email_address VARCHAR(40),\n", " login_name VARCHAR(40),\n", " password VARCHAR(40),\n", " PRIMARY KEY (person_id)\n", ");\n", "\n", "CREATE TABLE Students (\n", " student_id INTEGER,\n", " student_details VARCHAR(255),\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (student_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Courses (\n", " course_id VARCHAR(100),\n", " course_name VARCHAR(120),\n", " course_description VARCHAR(255),\n", " other_details VARCHAR(255),\n", " PRIMARY KEY (course_id)\n", ");\n", "\n", "CREATE TABLE People_Addresses (\n", " person_address_id INTEGER,\n", " person_id INTEGER,\n", " address_id INTEGER,\n", " date_from DATETIME,\n", " date_to DATETIME,\n", " PRIMARY KEY (person_address_id),\n", " FOREIGN KEY (address_id) REFERENCES Addresses(address_id),\n", " FOREIGN KEY (person_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Student_Course_Registrations (\n", " student_id INTEGER,\n", " course_id INTEGER,\n", " registration_date DATETIME,\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (course_id) REFERENCES Courses(course_id),\n", " FOREIGN KEY (student_id) REFERENCES Students(student_id)\n", ");\n", "\n", "CREATE TABLE Student_Course_Attendance (\n", " student_id INTEGER,\n", " course_id INTEGER,\n", " date_of_attendance DATETIME,\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (student_id) REFERENCES Student_Course_Registrations(student_id),\n", " FOREIGN KEY (course_id) REFERENCES Student_Course_Registrations(course_id)\n", ");\n", "\n", "CREATE TABLE Candidates (\n", " candidate_id INTEGER,\n", " candidate_details VARCHAR(255),\n", " PRIMARY KEY (candidate_id),\n", " FOREIGN KEY (candidate_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Candidate_Assessments (\n", " candidate_id INTEGER,\n", " qualification CHAR(15),\n", " assessment_date DATETIME,\n", " asessment_outcome_code CHAR(15),\n", " PRIMARY KEY (candidate_id),\n", " FOREIGN KEY (candidate_id) REFERENCES Candidates(candidate_id)\n", ");\n", "```\n", "\n", "# Question: How many students are attending English courses?<|im_end|>\n", "<|im_start|>assistant\n", "```json\n", "{\n", " 'Courses': ['course_id', 'course_name'],\n", " 'Student_Course_Attendance': ['student_id', 'course_id']\n", "}\n", "```<|im_end|>\n", "\n" ] } ], "source": [ "print(df['text'][df.index[70]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "puYY-BqFeMuu" }, "outputs": [], "source": [ "_df = pd.DataFrame(columns=['text'])\n", "_df['text'] = df.sample(frac=1, 
"_df['text'] = df.sample(frac=1, random_state=14).reset_index(drop=True)['text']\n", "_df = Dataset.from_pandas(_df)\n", "_df = _df.train_test_split(test_size=0.01, shuffle=True, seed=14)\n", "train_dataset, valid_dataset = _df[\"train\"], _df[\"test\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "DWpXeuO_KlLS" }, "source": [ "### Finetuning" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0oVpZDj1AXY9" }, "outputs": [], "source": [ "from huggingface_hub import login, create_repo\n", "from google.colab import userdata\n", "import wandb\n", "import os\n", "\n", "#token = userdata.get('hf_write')\n", "token = WRITE_TOKEN\n", "login(token=token)\n", "set_seed(1234)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KRhO7UJ-Q4Y8" }, "outputs": [], "source": [ "def find_all_linear_names(model, new_tokens=False):\n", " lora_module_names = set()\n", " for name, module in model.named_modules():\n", " if isinstance(module, bnb.nn.Linear4bit) or isinstance(module, bnb.nn.Linear8bitLt):\n", " names = name.split(\".\")\n", " lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n", " if new_tokens:\n", " lora_module_names.add(\"lm_head\")\n", " return list(lora_module_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "L0qqP5Y9PtRh", "outputId": "53602ce9-3893-4f04-e5fe-0c57500d949d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 7 linear modules to target with LoRA: ['gate_proj', 'down_proj', 'k_proj', 'up_proj', 'o_proj', 'q_proj', 'v_proj']\n" ] } ], "source": [ "modules = find_all_linear_names(model)\n", "print(f\"Found {len(modules)} linear modules to target with LoRA: {modules}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uFUnJrbjPwAT" }, "outputs": [], "source": [ "peft_config = LoraConfig(\n", " lora_alpha=128, # first version used 16\n", " lora_dropout=0.1,\n", " r=64,\n", " # bias=\"none\",\n", " # task_type=\"CAUSAL_LM\",\n", " target_modules=modules,\n", " # modules_to_save=[\"embed_tokens\"], # when adding special tokens\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "buh0o2P2jwbx" }, "outputs": [], "source": [ "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true, "base_uri": "https://localhost:8080/", "height": 920, "referenced_widgets": [ "d5fc9c76496d4fe8b3cf37520aa4a590", "d053ca741c054d2fa56c674d53c8b28e", "2d76bbe56cb74b3982f274650f2f3999", "5c5e9e32dc3148008a10e4ea1401abaf", "633a02f02dcf4a83b0a1d63431d27d6a", "c7267a6dfbf9475dbd58f4075404ccd6", "a5804014e5754948bd0f4c12f38724e4", "05b282e912384e15aef6e856c030cdcb", "4cfe15eb4804413ebf77386faaeb4f5a", "60b66e615fb643b880613d194bb2000f", "dc8e6aa81bcb483483b1fd1a42f7fc4d", "c748cf11cfa349cfaf344acb1a47a96a", "7e32d4b7d6a8466cbfcf350c6b91db2b", "fb2f6ac5b36b4d9f9a2a0999a68b85e0", "4964568991d943bb8a3298a755abd46b", "48c2f3f828a0482eb940a333b5e5bcfc", "50808f437a6b4721a0cd8a980614d963", "675f318a52a0457c944ef6c7d89635de", "54d2cee22c9e42a69fbad2ebd9f1111c", "34d7463529a045a984e1a66e342baaa2", "292356f9b20d4430853635c87c3dcaf9", "78cd64ba31114ab28c43733da790a4bd" ] }, "id": "9bD7ea0F-GQn", "outputId": "deebdbc1-7ae6-4cd7-8d88-5450d59fe3b1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field,
max_seq_length. Will not be supported from version '0.13.0'.\n", "\n", "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5fc9c76496d4fe8b3cf37520aa4a590", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/8569 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c748cf11cfa349cfaf344acb1a47a96a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/87 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "Validation Loss | \n", "
---|---|---|
250 | \n", "0.277400 | \n", "0.152294 | \n", "
500 | \n", "0.113000 | \n", "0.119896 | \n", "
750 | \n", "0.097600 | \n", "0.108571 | \n", "
1000 | \n", "0.088800 | \n", "0.105137 | \n", "
"
],
"text/plain": [
"