{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [],
   "source": [
    "# Environment setup. Pins match the versions this notebook was last run with, so a fresh\n",
    "# environment reproduces the run instead of silently pulling newer releases (the old -U did).\n",
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "# TODO: pin bitsandbytes and trl as well; their versions were not captured (installed with -q).\n",
    "%pip install huggingface_hub==0.23.3\n",
    "%pip install datasets==2.19.2 peft==0.11.1 \"transformers[torch]==4.41.2\"\n",
    "%pip install -q bitsandbytes trl accelerate==0.30.1\n",
    "%pip install flash-attn==2.5.9.post1 --no-build-isolation"
   ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "from pprint import pprint\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets import Dataset, load_dataset\n",
    "from huggingface_hub import notebook_login\n",
    "from peft import LoraConfig, PeftModel\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    "    TrainingArguments,\n",
    ")\n",
    "from trl import SFTTrainer"
   ] },
 { "cell_type": "code", "execution_count": 2,
"metadata": {}, "outputs": [], "source": [ "torch.cuda.set_per_process_memory_fraction(0.8) " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "66d43524d1a04d309785e243f6f016a8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
Step | \n", "Training Loss | \n", "Validation Loss | \n", "
---|---|---|
4 | \n", "1.981200 | \n", "1.774238 | \n", "
8 | \n", "1.753500 | \n", "1.754978 | \n", "
12 | \n", "1.772600 | \n", "1.726179 | \n", "
16 | \n", "1.557500 | \n", "1.692598 | \n", "
20 | \n", "1.790200 | \n", "1.662171 | \n", "
24 | \n", "1.598900 | \n", "1.633827 | \n", "
28 | \n", "1.526500 | \n", "1.615106 | \n", "
32 | \n", "1.560600 | \n", "1.600001 | \n", "
36 | \n", "1.618500 | \n", "1.588000 | \n", "
40 | \n", "1.646600 | \n", "1.579654 | \n", "
44 | \n", "1.657100 | \n", "1.567710 | \n", "
48 | \n", "1.587900 | \n", "1.558632 | \n", "
52 | \n", "1.397700 | \n", "1.550794 | \n", "
56 | \n", "1.704600 | \n", "1.543783 | \n", "
60 | \n", "1.456500 | \n", "1.538500 | \n", "
64 | \n", "1.901600 | \n", "1.532189 | \n", "
68 | \n", "1.774300 | \n", "1.528996 | \n", "
72 | \n", "1.390000 | \n", "1.524170 | \n", "
76 | \n", "1.558800 | \n", "1.519067 | \n", "
80 | \n", "1.627400 | \n", "1.517108 | \n", "
84 | \n", "1.504100 | \n", "1.512800 | \n", "
88 | \n", "1.668200 | \n", "1.509464 | \n", "
92 | \n", "1.526700 | \n", "1.505236 | \n", "
96 | \n", "1.618400 | \n", "1.503344 | \n", "
100 | \n", "1.451900 | \n", "1.499353 | \n", "
104 | \n", "1.644900 | \n", "1.496035 | \n", "
108 | \n", "1.469000 | \n", "1.492282 | \n", "
112 | \n", "1.614600 | \n", "1.489366 | \n", "
116 | \n", "1.591700 | \n", "1.487346 | \n", "
120 | \n", "1.487500 | \n", "1.482805 | \n", "
124 | \n", "1.416000 | \n", "1.480361 | \n", "
128 | \n", "1.313600 | \n", "1.481161 | \n", "
132 | \n", "1.334400 | \n", "1.479421 | \n", "
136 | \n", "1.471800 | \n", "1.476773 | \n", "
140 | \n", "1.540500 | \n", "1.474109 | \n", "
144 | \n", "1.452700 | \n", "1.473360 | \n", "
148 | \n", "1.323000 | \n", "1.472112 | \n", "
152 | \n", "1.527600 | \n", "1.470621 | \n", "
156 | \n", "1.535100 | \n", "1.469403 | \n", "
160 | \n", "1.356000 | \n", "1.467490 | \n", "
164 | \n", "1.492700 | \n", "1.465348 | \n", "
168 | \n", "1.371600 | \n", "1.464317 | \n", "
172 | \n", "1.628700 | \n", "1.463003 | \n", "
176 | \n", "1.242100 | \n", "1.462533 | \n", "
180 | \n", "1.284400 | \n", "1.461138 | \n", "
184 | \n", "1.563000 | \n", "1.459591 | \n", "
188 | \n", "1.421000 | \n", "1.457585 | \n", "
192 | \n", "1.208200 | \n", "1.456179 | \n", "
196 | \n", "1.350800 | \n", "1.454647 | \n", "
200 | \n", "1.602600 | \n", "1.454009 | \n", "
\n", "