{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 22207, "status": "ok", "timestamp": 1731936919321, "user": { "displayName": "HighSchoolMinas - MinasCoders", "userId": "06806689059039259777" }, "user_tz": 180 }, "id": "kH18jD5cR_Ks", "outputId": "924139c0-7305-4d4f-e4c2-e9cbf88b69cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m122.4/122.4 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.9/310.9 kB\u001b[0m \u001b[31m8.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m7.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "# !pip install -q accelerate peft bitsandbytes transformers trl faiss-gpu langchain_community wandb flash-attn\n", "# NOTE: %pip (not !pip) so the install targets this kernel's environment\n", "%pip install -q accelerate peft bitsandbytes transformers trl datasets\n", "\n", "# flash-attn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cgVNTbBa-D3j" }, "outputs": [], "source": [ "# load the required packages.\n", "import torch\n", "from datasets import load_dataset, Dataset\n", "from peft import LoraConfig, AutoPeftModelForCausalLM, PeftModel, get_peft_model\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, AutoConfig, set_seed\n", "from trl import SFTTrainer\n", "import bitsandbytes as bnb\n", "import transformers\n", "\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import sqlparse\n", "import re\n", "import json\n", "\n", "from huggingface_hub import hf_hub_download\n", "from huggingface_hub import HfFileSystem" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "basaX_55Yf_D" }, "outputs": [], "source": [ "#transformers.logging.set_verbosity_info()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bkkjgGdlrNcq" }, "outputs": [], "source": [ "# import here so this cell works on Restart & Run All\n", "# (previously `userdata` was only imported in a much later cell -> NameError)\n", "from google.colab import userdata\n", "\n", "WRITE_TOKEN = userdata.get('hf_write')\n", "READ_TOKEN = userdata.get('hf_read')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7CKnwlRfZj4V" }, "outputs": [], "source": [ "model_name = \"stabilityai/stable-code-instruct-3b\"\n", "out_name = \"lleticiasilvaa/StableCode-text2SQL-schemaReduzido\"\n", "prev_checkpoint = None #\"checkpoint-500\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zs7nCmt-pMC" }, "outputs": [], "source": [ "#!huggingface-cli login" ] }, { "cell_type":
"code", "execution_count": null, "metadata": { "id": "PZdnxs8k-Cgl" }, "outputs": [], "source": [ "spider_id=\"NESPED-GEN/spider_selector_schemaReduzido\"" ] }, { "cell_type": "markdown", "metadata": { "id": "xT2iRdCN_MFH" }, "source": [ "### Load Data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 264, "referenced_widgets": [ "da5c091084d144b5bcf1c09c0dcde4f0", "2259a5d56ca14424b4b83e4814d8ed5c", "4f827e7e9c8143a9a66fbc5a337453c6", "41db38eeaa224263a915868a767f7a7f", "8a3c6a8ba4f4442487b588008d833250", "4e7d6f28d53f4f5a9602bbeb97b18977", "4eecfd540245491baa0eef44e55f8c5c", "35a226337cd0484c9f1d16c2207c24a9", "3f953a090a6e40cc9c170495d55477da", "55282b1ef4634746853550df79492781", "0b2336d126e84bc095a24971608f4187", "c13e11608e7a461a942a575b54e764c7", "d0d513bf8cad48b3bdd9870f39f36058", "f020a3012eb2412683960caa0b0e8d22", "a34b8d9700864ea3bce651705bfaebf5", "3a24ff07ca894ac68f0f35bfeef09c31", "bb483effb81248b99911546ddd11e0d7", "918dfb6e0be647b896e4a707fe7a2c7c", "b566d6281d4140efb609008c5cf16a34", "6ef28a18e0a14ce59074c76b0fa46cea", "61e892dcf80f40cd9e1c895c678a9c9d", "553fad3cddd344dea8d52f3ab2a49f85", "e4e00110363d4dddb9ba996c038743da", "6f0b37bffb4e4fed990a0a9bf8b58829", "714ceded01bf48118a6224048571a8e5", "4b044760037a454192ed3220e1b92845", "6347edc1ac59495b94700bce2cc8d802", "e982e5da491843a5aecf0778f9f86552", "d52845ae280040059faf6f46b64acbaa", "42a6626eea084c44a3aa409e45d23592", "1a99c1df706646d09b604d130cface07", "82365e54a4ad44248642db67f854f9a7", "cd4305d9e72c4350920592b70a5032f9", "6769642dae9146ba87f576afa219f23b", "ce39f86d765f4bf4a6be248ce6a91b53", "7470bd376b2542f39be29048307b2f4b", "78752f93c4f44cb2a3d4ba9f1a0f515c", "46163d43293a449e969871a3380d6662", "f2218f51cc49430b86f8447f143f14ee", "90d871c75b4d409682fb7df71431f948", "e54f383092c54422bba435911e5bc132", "090c233031144c32a0c558f1a0f4d856", "af07adf75b694c27a4ff47d2b94ada12", "2475467d81314b969eabe7cbaa688215", 
"71f55dee7ad244db921fa67a1ac3a9b9", "e59c79d891db4489a03ed61569fc4c76", "0f9c5ff272944cdc9b91a4c4a50d5d61", "b9a2e6536ed14af0b1172057c18fa0ae", "151e358e1c18478b8563e7b170512865", "447ed2f1f5ad4800973ca622ed1d0d17", "e8d33ddec7f949608ca80670656fe9de", "92ebec4232bb49b3a2d89142596613d7", "c60acfa5afef4c21b4ee82d17e354f3c", "d9056a6ca2484d98b03b0fa9d8631158", "10d1a188428248c1bc62ee11fe8131ce" ] }, "executionInfo": { "elapsed": 10362, "status": "ok", "timestamp": 1731936986600, "user": { "displayName": "HighSchoolMinas - MinasCoders", "userId": "06806689059039259777" }, "user_tz": 180 }, "id": "lLKgY40efdJo", "outputId": "b576abed-3fae-4e52-d094-ab270c9f8541" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "da5c091084d144b5bcf1c09c0dcde4f0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/885 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c13e11608e7a461a942a575b54e764c7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "dev-00000-of-00001.parquet: 0%| | 0.00/369k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4e00110363d4dddb9ba996c038743da", "version_major": 2, "version_minor": 0 }, "text/plain": [ "train-00000-of-00001.parquet: 0%| | 0.00/2.70M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6769642dae9146ba87f576afa219f23b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating dev split: 0%| | 0/1034 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "71f55dee7ad244db921fa67a1ac3a9b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0%| | 0/8656 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "Dataset({\n", " features: ['db_id', 'question_en', 'query', 'answer', 'hardness', 'query_llm', 'schema_SQLDatabase', 'schema_our', 'schema_dict', 'selector', 'selector_correct', 'schema_SQLDatabase_reduzido', 'schema_SQLDatabase_reduzido_tabelas'],\n", " num_rows: 8656\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spider = load_dataset(spider_id, split=\"train\")\n", "spider" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "po1gNdLjFKdZ" }, "outputs": [], "source": [ "df = spider.to_pandas()" ] }, { "cell_type": "markdown", "metadata": { "id": "ON-hPsEYM1Bu" }, "source": [ "# Load Base Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yEAZpfzlNOHW" }, "outputs": [], "source": [ "def download_checkpoint(adapter_model_id, checkpoint):\n", " fs = HfFileSystem()\n", " for file in fs.ls(f'{adapter_model_id}/{checkpoint}', detail=False):\n", " file_name = file.split(checkpoint)[-1]\n", "\n", " hf_hub_download(repo_id=adapter_model_id, filename=(f'{checkpoint}{file_name}'), local_dir='out')\n", "\n", " for file in fs.ls(f'{adapter_model_id}/logs', detail=False):\n", " file_name = file.split(checkpoint)[-1]\n", "\n", " hf_hub_download(repo_id=adapter_model_id, filename=(f'logs/{file_name.split(\"/\")[-1]}'), local_dir='out')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "enUxjGXEqHxg" }, "outputs": [], "source": [ "# download_checkpoint(out_name, prev_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 337, "referenced_widgets": [ "405c36be5d8c4c5f9ac5e133f5e0e5a5", "ab8767a760b84eae9bc1e9e4919bf5c5", "aae55e8e55fb4da9aa7474c04d8b53d5", "7aa964bc37d44b3582e666ffb41bdf44", "9d906b6bb69b47f4a903ed6d771f4865", "3800fc4263a1406683ad5ae4f161cf10", 
"aea483f207654f5b8fbe48e9a3dd5c79", "fc82986f744a4d6dbd55c228daf7b72c", "6fdf5ae86e3a41fe9096f70111f41c24", "ea99f1883f4847908a0122465b68152e", "ac516e6a38684139a7ba1a8e444934a3", "336fa077ec1446d4a4054617ab5974fc", "e6972d36eac3423d8f8c3428a8eed64c", "aad020e4ef7c445fac60cbf6febca1e4", "54b98753fd304db98d8185061a6d38d0", "a17345800475437faf657e7546367762", "32429ea22aa04fa9931c1957e6a8efb4", "09356a5fd3d249ac8b070e543a5fad06", "50384545e13a486096d6b57d20758f3b", "b361f9b783d846ccafe3a2bccd75732d", "4eb34f295dc84c6a87ce39928c8efcde", "4aef525021654be78c39be799b54421d", "675baeb5cb684edf8c68f8c4b312d0cd", "77e5f0711ee2407b8661e6b19527053e", "7e055162bde6458aa18d8f1507fa35c6", "6951a45f568148d097159fa4e137cc92", "3d3c96b3de5245de84921649ad906e99", "f881fc2f9fe244188ee6c707035e2d8d", "6b8a7a581e2943a3baeea5949f6aaee8", "8fa4649b69e1495f8ae427daecac9f56", "2e250483bf93480ba4d74da7200e6d6f", "2d2fd7f808d94ba38c62c1abf8486cf5", "b433e47ec9aa4532be882b6623efd63e", "55d8da75aae24d98802b71dece0e57bb", "3a636968a51b484d9cf636a7eb39ed52", "7639596ffcfa4f898cb2f231843725d7", "e7a5e53b243f450fad00bb298a62963b", "d2c3e00478b04c458dc4270bf7a7c41d", "f3061bbb84514ef4a6061f56d4b3407d", "278bfd17a5544273a4cab73db076266f", "cc7ab09789534966b058d29e4a2e193c", "202ffde4588345ea969d9004fb9ff6db", "839ef4239562451b9f7578a3716b27f5", "f58ff9cba4f44fb8922bd701ec33b4d9", "8ff5553c1d5f4b46bc035d317c995076", "9a068e7884414287baa24faa945f1641", "c915928f29e34c31bbd8197dbd519cb5", "4b0d0c255fd44f05b91be3d5443535d1", "ea1c82848f394c328c9a11046063f3de", "67324d9e136c4c86bfb80f0656300ac0", "4268e90eefdb46a1b350b0a4ded60333", "01dd3f6dd83d4b348c7848b94f6f9969", "0092acef600d48c48f331e21cb8ff9ca", "e0e09a805b1540b09e4729f71d19e69c", "4b1d99d7adc44e1a92533c2564820d4b", "40f96c493ea2477992c0f977d540f5a8", "42a87cd354e44bfb98500392425b5310", "015af8dc4dfb4ab9a1a3f04ec0b7f398", "2f3b3ac67d6a42aca0963d343a246249", "1d5d306462a544ae844526c38db22d53", "75eda375d665450a816b4c6a97923157", 
"73ea824dc10f49b19c61fe43d0a165b0", "23e3268d574645038523ad63597b88d7", "967011e6b9584abd98a31dfe87fa6095", "06de9befc68046c497660894f1a90acc", "402c0ccdd58240edaf4947a2dae19928", "d36e1146d77b4105b693556aae66f811", "6e3a0d789f7a45fbaee17e36b3874ef6", "b62767fce186437ba57e1f4253b2037d", "b629c1d8af5b4d4c850b5168b2c9fe7f", "141e0a081a3d4738bf547a7a411809ff", "711e2c43a4f549fd8bbd4892736ee0d2", "9a05b1ec209a4e7ead6c0d37cd3b4424", "cf0ad160f50048cc906e25e1c2c8c70a", "9f5870390c714ebbbcfc42ed7c4bf90e", "0451ff0a4e524c31bb18e7c4ce7cf890", "689fdf70288e484badda81568ea591a8", "e101711685144e8aba3bcd3589fdadf0", "fcbff61595e3445ba5e899d586f35640", "327986ff08f84a6cb8bbcab089c136d5", "3f8eb7b1b6c3459fb89d5af89a84294d", "3ec2d2df48334f619eb42d0da989d749", "37b1562aa0b749d3b80ec58eba216cd9", "d15b5eab95194d928a0aeff39c7e0df2", "b30859c01ef34e2c8452451bc35bc65e", "f3027187a89945c7891a58d748671c43", "3297f623e2604099aec0720fd208b32a", "2c91c90594494f2ba978c7f531746d40", "3d39bded9c394adaa035c883b7b26538", "142b2cc271084232b27a62b132b1cea5", "72a2e266aa484476a46797f06431a722", "da9bab0916d34e69a697a3e05ee03f26", "295d9f93621d4f21abdce758c7298252", "7c3005b5bffe4707a397d863edf2fa54", "562d34aa9458447284de34168a7553f5", "ee08faa9687347c6a786669fb35e0bbe", "a16c41376d7c40cfa261047e3e705330", "fd7e3eda9c9d4495b39989b32bf60e1b", "a841a9c59a694684ab05520dc75d5891", "771a2968cb03418099f6eb0804c11669", "4fadf3ccdeb94445add6098ad5243e8a", "19cbdb81690f4990be64624b8f70fe7f", "625f4831a6774baea77ecc7d64b36279", "b820dfae028e41d3b34adfbbead5cb4f", "aaf2b0b2140a429d9d97c32de3d097a2", "3d903f4ce76c4fd0a205b3ff2b0f5814", "3c5853848d4541d4b1b14d368fb4093c", "4d171ed5611a4c15bdfccfabc883a7b6", "9ef0fbd7e7414305abbdce94bfeb1693", "37391a58d6444cf4bb5d4e529750d552" ] }, "executionInfo": { "elapsed": 173065, "status": "ok", "timestamp": 1731937187115, "user": { "displayName": "HighSchoolMinas - MinasCoders", "userId": "06806689059039259777" }, "user_tz": 180 }, "id": "M7DoqQMlM_nW", "outputId": 
"918430ce-7047-4032-9c1d-cdee271db905" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "405c36be5d8c4c5f9ac5e133f5e0e5a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/9.35k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "336fa077ec1446d4a4054617ab5974fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/2.12M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "675baeb5cb684edf8c68f8c4b312d0cd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/587 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "55d8da75aae24d98802b71dece0e57bb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/738 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8ff5553c1d5f4b46bc035d317c995076", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors.index.json: 0%| | 0.00/29.4k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "40f96c493ea2477992c0f977d540f5a8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading shards: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d36e1146d77b4105b693556aae66f811", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00001-of-00002.safetensors: 0%| | 0.00/4.98G [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "e101711685144e8aba3bcd3589fdadf0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00002-of-00002.safetensors: 0%| | 0.00/610M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3d39bded9c394adaa035c883b7b26538", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "771a2968cb03418099f6eb0804c11669", "version_major": 2, "version_minor": 0 }, "text/plain": [ "generation_config.json: 0%| | 0.00/132 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "seed=14\n", "\n", "\n", "if (prev_checkpoint != None):\n", " try:\n", " download_checkpoint(out_name, prev_checkpoint)\n", " except:\n", " pass\n", "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=READ_TOKEN, map_device=\"auto\", add_eos_token=True, use_fast=True)\n", "\n", "new_tokens = {'additional_special_tokens': ['[SQL]','[/SQL]', '[QUESTION]','[/QUESTION]']}\n", "#adicionar tokens especiais:\n", "# if (prev_checkpoint == None):\n", "# tokenizer.add_special_tokens(new_tokens)\n", "\n", "\n", "if torch.cuda.is_bf16_supported():\n", " compute_dtype = torch.bfloat16\n", " attn_implementation = 'flash_attention_2'\n", "else:\n", " compute_dtype = torch.float16\n", " attn_implementation = 'sdpa'\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"\n", "\n", "\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=compute_dtype,\n", " bnb_4bit_use_double_quant=False,\n", ")\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " torch_dtype=compute_dtype,\n", " 
device_map=\"auto\",\n", " quantization_config=bnb_config,\n", "\n", " trust_remote_code=True,\n", " token=READ_TOKEN,\n", " # attn_implementation=attn_implementation\n", ")\n", "\n", "# se adicionar special_tokens tem que fazer resize do tokenizer:\n", "# model.resize_token_embeddings(len(tokenizer))\n", "\n", "## model.resize_token_embeddings(max(len(tokenizer), model.config.vocab_size))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 8479, "status": "ok", "timestamp": 1730570376902, "user": { "displayName": "Leticia Oliveira Silva", "userId": "01512049874517593223" }, "user_tz": 180 }, "id": "fLuqzhSJBvi8", "outputId": "65293049-fcc8-491b-aaca-579bca0686c7" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n", "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. 
To disable this, use `mean_resizing=False`\n" ] }, { "data": { "text/plain": [ "Embedding(32004, 2048)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# se adicionar special_tokens tem que fazer resize do tokenizer:\n", "#model.resize_token_embeddings(len(tokenizer))\n", "\n", "#model.resize_token_embeddings(max(len(tokenizer), model.config.vocab_size))" ] }, { "cell_type": "markdown", "metadata": { "id": "_I7-bFfm5gqS" }, "source": [ "#### Chat Template - Gerar SQL" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cYVA3Q7ZCzHi" }, "outputs": [], "source": [ "# tokenizer.chat_template = \"\"\"\n", "# {% if messages[0]['role'] == 'system' %}\n", "# {% set loop_messages = messages[1:] %}\n", "# {% set system_message = messages[0]['content'] %}\n", "# {% else %}\n", "# {% set loop_messages = messages %}\n", "# {% set system_message = 'Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.' 
%}\n", "# {% endif %}\n", "# {{ '# <|system|>/n/' + system_message + '/n//n/' }}\n", "# {% if messages|selectattr(\"role\", \"equalto\", \"example\")|list %}\n", "# Below are some examples of question and their corresponding SQL queries:/n//n/\n", "# {% else %}\n", "# /n/\n", "# {% endif %}\n", "# {% for message in loop_messages %}\n", "# {% if message['role'] == 'example' %}\n", "# {{ message['content'] }}/n//n/\n", "# {% elif message['role'] == 'schema' %}\n", "# # <|schema|>/n/The query will run on a database with the following schema:/n/{{ message['content'] }}/n//n/\n", "# {% elif message['role'] == 'user' %}\n", "# # <|user|>/n/[QUESTION]{{ message['content'] }}[/QUESTION]/n//n/\n", "# {% elif message['role'] == 'assistant' %}\n", "# # <|assistant|>/n/[SQL]{{ message['content'] }}[/SQL]\n", "# {% endif %}\n", "# {% endfor %}\n", "# {% if add_generation_prompt %}\n", "# # <|assistant|>/n/[SQL]\n", "# {% endif %}\n", "# \"\"\".replace(\"\\n\",\"\").replace(\" \", \"\").replace(\"/n/\", \"\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IE87Vlt1hrje" }, "outputs": [], "source": [ "import re\n", "\n", "def replace_alias_with_table(query):\n", " # Expressão regular para encontrar tabelas com alias, capturando o nome da tabela e o alias\n", " alias_pattern = re.compile(r'(\\bFROM\\b|\\bJOIN\\b)\\s+(\\w+)\\s+AS\\s+(\\w+)', re.IGNORECASE)\n", "\n", " # Substituições de aliases encontrados no padrão\n", " aliases = {match.group(3): match.group(2) for match in alias_pattern.finditer(query)}\n", "\n", " # Substituir cada alias pelo nome da tabela correspondente\n", " for alias, table in aliases.items():\n", " query = re.sub(r'\\b' + alias + r'\\b', table, query)\n", "\n", " # Remover 'AS' e alias das cláusulas 'FROM' e 'JOIN'\n", " query = re.sub(r'\\bAS\\s+\\w+', '', query, flags=re.IGNORECASE)\n", " return query" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6RDDdVgP5gqT" }, "outputs": [], "source": [ "def 
to_sql(query):\n", " return sqlparse.format(replace_alias_with_table(query), reindent=True, keyword_case='upper')\n", "\n", "def apply_template(row, tokenizer=tokenizer, n_examplea=0):\n", " question = row['question_en']\n", " schema = row['schema_SQLDatabase_reduzido']\n", " sql = to_sql(row['query'])\n", "\n", " system = \"Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.\"\n", "\n", " chat = [\n", " {'role': 'system', 'content': system},\n", " {'role': 'user', 'content': f\"# Schema:\\n```sql\\n{schema}\\n```\\n\\n# Question: {question}\"},\n", " {'role': 'assistant', 'content': f\"```sql\\n{sql}\\n```\\n\"}\n", " ]\n", "\n", " # chat = [\n", " # {\"role\": \"schema\", \"content\": schema},\n", " # {\"role\": \"user\", \"content\": question},\n", " # {\"role\": \"assistant\", \"content\": sql},\n", " # ]\n", "\n", " row['text'] = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)\n", "\n", " return row" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2trHEegL5gqU" }, "outputs": [], "source": [ "# spider_chain = json.load(open(\"/content/drive/Shareddrives/LLMs/Datasets/spider/spider_chain_of_thought.json\", \"r\"))\n", "# bird_chain = json.load(open(\"/content/drive/Shareddrives/LLMs/Datasets/bird/bird_chain_of_thought.json\", \"r\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N4jWrC7s5gqU" }, "outputs": [], "source": [ "# df['CoT'] = spider_chain + bird_chain" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bTF0pBsw5gqU" }, "outputs": [], "source": [ "df = df.apply(apply_template, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L4tjUv7o5gqV" }, "outputs": [], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DfJvLaGR5gqV" }, "outputs": [], "source": [ "# 
df['n_tokens'] = df['text'].apply(lambda x: len(tokenizer.encode(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vJseOHIu5gqW" }, "outputs": [], "source": [ "# import seaborn as sns\n", "# sns.histplot(df['n_tokens'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 5, "status": "ok", "timestamp": 1731629325148, "user": { "displayName": "LLeticia_o@hotmail.com", "userId": "07103443173840805964" }, "user_tz": 180 }, "id": "PIvSnr6Y5gqW", "outputId": "2d3c87b0-eafc-487a-d563-8dc708283d32" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema.<|im_end|>\n", "<|im_start|>user\n", "# Schema:\n", "```sql\n", "CREATE TABLE city (\n", " City_ID INT,\n", " Status TEXT,\n", " Population REAL,\n", " PRIMARY KEY (City_ID)\n", ");\n", "```\n", "\n", "# Question: Show the status shared by cities with population bigger than 1500 and smaller than 500.<|im_end|>\n", "<|im_start|>assistant\n", "```sql\n", "SELECT Status\n", "FROM city\n", "WHERE Population > 1500 INTERSECT\n", " SELECT Status\n", " FROM city WHERE Population < 500\n", "```\n", "<|im_end|>\n", "\n" ] } ], "source": [ "print(df['text'][df.index[50]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "roZzKNOj5gqW" }, "outputs": [], "source": [ "_df = pd.DataFrame(columns=['text'])\n", "_df['text'] = df.sample(frac=1, random_state=14).reset_index(drop=True)['text']\n", "_df = Dataset.from_pandas(_df)\n", "_df = _df.train_test_split(test_size=0.01, shuffle=True, seed=14)\n", "train_dataset, valid_dataset = _df[\"train\"], _df[\"test\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "b6mjOblXeMup" }, "source": [ "#### Chat Template - Schema Linking" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "id": "opbdu1g6eMuq" }, "outputs": [], "source": [ "def apply_template(row, tokenizer=tokenizer, n_examplea=0):\n", " question = row['question_en']\n", " schema = row['schema_SQLDatabase']\n", " schema_linking = row['selector_correct']\n", "\n", " system = \"Given a user question and the schema of a database, your task is to generate an JSON with the the names of tables and columns of the schema that the question is referring to.\"\n", "\n", " chat = [{'role': 'user', 'content': f\"# System:\\n{system}\\n\\n# Schema:\\n```sql\\n{schema}\\n```\\n\\n# Question: {question}\"},\n", " {'role': 'assistant', 'content': f\"```json\\n{schema_linking}\\n```\"}\n", " ]\n", "\n", " row['text'] = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)\n", "\n", " return row" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_HHKJ9VeMur" }, "outputs": [], "source": [ "df = df.apply(apply_template, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oF0-4l8DeMus" }, "outputs": [], "source": [ "# df['n_tokens'] = df['text'].apply(lambda x: len(tokenizer.encode(x)))\n", "# import seaborn as sns\n", "# sns.histplot(df['n_tokens'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1731937378053, "user": { "displayName": "HighSchoolMinas - MinasCoders", "userId": "06806689059039259777" }, "user_tz": 180 }, "id": "EFY_pTEteMut", "outputId": "a71630ce-da91-4aac-e8ca-18886271f22d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|im_start|>system\n", "You are a helpful assistant.<|im_end|>\n", "<|im_start|>user\n", "# System:\n", "Given a user question and the schema of a database, your task is to generate an JSON with the the names of tables and columns of the schema that the question is 
referring to.\n", "\n", "# Schema:\n", "```sql\n", "CREATE TABLE Addresses (\n", " address_id INTEGER,\n", " line_1 VARCHAR(80),\n", " line_2 VARCHAR(80),\n", " city VARCHAR(50),\n", " zip_postcode CHAR(20),\n", " state_province_county VARCHAR(50),\n", " country VARCHAR(50),\n", " PRIMARY KEY (address_id)\n", ");\n", "\n", "CREATE TABLE People (\n", " person_id INTEGER,\n", " first_name VARCHAR(255),\n", " middle_name VARCHAR(255),\n", " last_name VARCHAR(255),\n", " cell_mobile_number VARCHAR(40),\n", " email_address VARCHAR(40),\n", " login_name VARCHAR(40),\n", " password VARCHAR(40),\n", " PRIMARY KEY (person_id)\n", ");\n", "\n", "CREATE TABLE Students (\n", " student_id INTEGER,\n", " student_details VARCHAR(255),\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (student_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Courses (\n", " course_id VARCHAR(100),\n", " course_name VARCHAR(120),\n", " course_description VARCHAR(255),\n", " other_details VARCHAR(255),\n", " PRIMARY KEY (course_id)\n", ");\n", "\n", "CREATE TABLE People_Addresses (\n", " person_address_id INTEGER,\n", " person_id INTEGER,\n", " address_id INTEGER,\n", " date_from DATETIME,\n", " date_to DATETIME,\n", " PRIMARY KEY (person_address_id),\n", " FOREIGN KEY (address_id) REFERENCES Addresses(address_id),\n", " FOREIGN KEY (person_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Student_Course_Registrations (\n", " student_id INTEGER,\n", " course_id INTEGER,\n", " registration_date DATETIME,\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (course_id) REFERENCES Courses(course_id),\n", " FOREIGN KEY (student_id) REFERENCES Students(student_id)\n", ");\n", "\n", "CREATE TABLE Student_Course_Attendance (\n", " student_id INTEGER,\n", " course_id INTEGER,\n", " date_of_attendance DATETIME,\n", " PRIMARY KEY (student_id),\n", " FOREIGN KEY (student_id) REFERENCES Student_Course_Registrations(student_id),\n", " FOREIGN KEY (course_id) REFERENCES 
Student_Course_Registrations(course_id)\n", ");\n", "\n", "CREATE TABLE Candidates (\n", " candidate_id INTEGER,\n", " candidate_details VARCHAR(255),\n", " PRIMARY KEY (candidate_id),\n", " FOREIGN KEY (candidate_id) REFERENCES People(person_id)\n", ");\n", "\n", "CREATE TABLE Candidate_Assessments (\n", " candidate_id INTEGER,\n", " qualification CHAR(15),\n", " assessment_date DATETIME,\n", " asessment_outcome_code CHAR(15),\n", " PRIMARY KEY (candidate_id),\n", " FOREIGN KEY (candidate_id) REFERENCES Candidates(candidate_id)\n", ");\n", "```\n", "\n", "# Question: How many students are attending English courses?<|im_end|>\n", "<|im_start|>assistant\n", "```json\n", "{\n", " 'Courses': ['course_id', 'course_name'],\n", " 'Student_Course_Attendance': ['student_id', 'course_id']\n", "}\n", "```<|im_end|>\n", "\n" ] } ], "source": [ "print(df['text'][df.index[70]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "puYY-BqFeMuu" }, "outputs": [], "source": [ "_df = pd.DataFrame(columns=['text'])\n", "_df['text'] = df.sample(frac=1, random_state=14).reset_index(drop=True)['text']\n", "_df = Dataset.from_pandas(_df)\n", "_df = _df.train_test_split(test_size=0.01, shuffle=True, seed=14)\n", "train_dataset, valid_dataset = _df[\"train\"], _df[\"test\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "DWpXeuO_KlLS" }, "source": [ "### Finetuning" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0oVpZDj1AXY9" }, "outputs": [], "source": [ "from huggingface_hub import login, create_repo\n", "from google.colab import userdata\n", "import wandb\n", "import os\n", "\n", "#token = userdata.get('hf_write')\n", "token = WRITE_TOKEN\n", "login(token=token)\n", "set_seed(1234)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KRhO7UJ-Q4Y8" }, "outputs": [], "source": [ "def find_all_linear_names(model, new_tokens=False):\n", " lora_module_names = set()\n", " for name, module in 
model.named_modules():\n", " if isinstance(module, bnb.nn.Linear4bit) or isinstance(module, bnb.nn.Linear8bitLt):\n", " names = name.split(\".\")\n", " lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n", " if(new_tokens):\n", " lora_module_names.add(\"lm_head\")\n", " return list(lora_module_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 503, "status": "ok", "timestamp": 1731937441903, "user": { "displayName": "HighSchoolMinas - MinasCoders", "userId": "06806689059039259777" }, "user_tz": 180 }, "id": "L0qqP5Y9PtRh", "outputId": "4d415c71-e1cb-4bf1-ef73-d6d76cd35d44" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 7 modules to quantize: ['up_proj', 'v_proj', 'q_proj', 'k_proj', 'o_proj', 'gate_proj', 'down_proj']\n" ] } ], "source": [ "modules = find_all_linear_names(model)\n", "print(f\"Found {len(modules)} modules to quantize: {modules}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uFUnJrbjPwAT" }, "outputs": [], "source": [ "peft_config = LoraConfig(\n", " lora_alpha=128, #primeira versão = 16\n", " lora_dropout=0.1,\n", " r=64,\n", " # bias=\"none\",\n", " # task_type=\"CAUSAL_LM\",\n", " target_modules=modules,\n", " # modules_to_save=[\"embed_tokens\"], #quando adicionar tokens speciais\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "buh0o2P2jwbx" }, "outputs": [], "source": [ "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true, "base_uri": "https://localhost:8080/", "height": 736, "referenced_widgets": [ "13509555b5c3489dbfbe64f2f6f284df", "1ac85930f80f46f0bace12efed6ef9cb", "80d438af0fb44b479f1a2ccfa15f27fc", "7833338232914b5bb8144c8844d4b4e5", "7bd5866d6064425e8d4604279f47ad7a", "a3802f89dba94a3697d534522a2bb786", "2afb99816a8e4029966f3b84c7792931", 
"0ea6db73f2444a7d92e39d8b3725a954", "06888bf0c7ac4d9fbde3b2e66de722fa", "e8e7905f02ec444ea5c10d4b3846534a", "0d690a90e84e41cbb12e37301e8adadb", "2fe54466218640b3acaefe02890bffbd", "41360284ac634c8081df2b01449c2ca7", "7faceca502ed4af091819f5b15688569", "c63a4ad82fa946a3be5e39f61c5faba4", "2b69aa1e7f5f4c6d9995ccdbbe7e8238", "195abb3ae03744bebacbfd2dbe874e19", "535755d5b72d4ca49e469007e1663101", "fa38563f9546433597730c810187afbd", "80a51b467dda4e6293fd2e560badaa29", "93498cdfa1f248b98245086dc49cbfa1", "cf196d5297fa4e6d91a59729eeb2ef16" ] }, "id": "9bD7ea0F-GQn", "outputId": "1392d82f-e296-49ca-ea27-50030d6adb5c" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length. Will not be supported from version '0.13.0'.\n", "\n", "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "13509555b5c3489dbfbe64f2f6f284df", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/8569 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fe54466218640b3acaefe02890bffbd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/87 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "Validation Loss | \n", "
---|---|---|
250 | \n", "0.276800 | \n", "0.150541 | \n", "
500 | \n", "0.109200 | \n", "0.114981 | \n", "
750 | \n", "0.094500 | \n", "0.104744 | \n", "
1000 | \n", "0.086200 | \n", "0.101263 | \n", "
"
],
"text/plain": [
"