{ "cells": [ { "cell_type": "markdown", "source": [ "**종속성 설치**" ], "metadata": { "id": "QGv6ld9q1XHv" }, "id": "QGv6ld9q1XHv" }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "nDZe_wqKU6J3", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "5213a3aa-dcf4-47af-b508-dcb224a701d1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", "Collecting peft\n", " Downloading peft-0.7.0-py3-none-any.whl (168 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.3/168.3 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting accelerate\n", " Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting optimum\n", " Downloading optimum-1.15.0-py3-none-any.whl (400 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m400.9/400.9 kB\u001b[0m \u001b[31m43.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting bitsandbytes\n", " Downloading bitsandbytes-0.41.3-py3-none-any.whl (92.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting trl\n", " Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting wandb\n", " Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m74.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting einops\n", " Downloading einops-0.7.0-py3-none-any.whl (44 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n", "Collecting coloredlogs (from optimum)\n", " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n", "Collecting datasets (from optimum)\n", " Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m52.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting tyro>=0.5.11 (from trl)\n", " Downloading tyro-0.6.0-py3-none-any.whl (100 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", " Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.6/190.6 kB\u001b[0m \u001b[31m24.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting sentry-sdk>=1.0.0 (from wandb)\n", " Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m29.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", "Collecting setproctitle (from wandb)\n", " Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n", "Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers)\n", " Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m47.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n", " Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n", "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n", " Downloading shtab-1.6.5-py3-none-any.whl (13 kB)\n", "Collecting humanfriendly>=9.1 (from coloredlogs->optimum)\n", " Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n", "Collecting pyarrow-hotfix (from datasets->optimum)\n", " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets->optimum)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n", "Collecting multiprocess (from datasets->optimum)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n", "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n", "Installing collected packages: sentencepiece, bitsandbytes, smmap, shtab, setproctitle, sentry-sdk, pyarrow-hotfix, humanfriendly, einops, docstring-parser, docker-pycreds, dill, multiprocess, gitdb, coloredlogs, tyro, GitPython, accelerate, wandb, datasets, trl, peft, optimum\n", "Successfully installed GitPython-3.1.40 accelerate-0.25.0 bitsandbytes-0.41.3 coloredlogs-15.0.1 datasets-2.15.0 dill-0.3.7 docker-pycreds-0.4.0 docstring-parser-0.15 einops-0.7.0 gitdb-4.0.11 humanfriendly-10.0 multiprocess-0.70.15 optimum-1.15.0 peft-0.7.0 pyarrow-hotfix-0.6 sentencepiece-0.1.99 sentry-sdk-1.38.0 setproctitle-1.3.3 shtab-1.6.5 smmap-5.0.1 trl-0.7.4 tyro-0.6.0 wandb-0.16.1\n" ] } ], "source": [ "pip install transformers peft accelerate optimum bitsandbytes trl wandb einops" ], "id": "nDZe_wqKU6J3" }, { "cell_type": "markdown", "source": [ "라이브러리 및 모듈 임포트" ], "metadata": { "id": "hNUtFKCYBGm4" }, "id": "hNUtFKCYBGm4" }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d", "outputId": "ed07918d-f1c3-4772-bcf5-f8b0926025a1" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", " warnings.warn(\n" ] } ], "source": [ "import os\n", "from dataclasses import dataclass, field\n", "from typing import Optional\n", "import re\n", "\n", "import torch\n", "import tyro\n", "from accelerate import Accelerator\n", "from datasets import load_dataset, Dataset\n", "from peft import AutoPeftModelForCausalLM, LoraConfig\n", "from tqdm import tqdm\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", " TrainingArguments,\n", ")\n", "\n", "from trl import SFTTrainer\n", "\n", "from trl.trainer import ConstantLengthDataset" ], "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d" }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 158, "referenced_widgets": [ "0f44dbac37424152ac0bdef0e16e4771", "eec756ae16fa497aa0c5fd9670a2862c", "897a4815dd0b447480a212ff81cf1069", "c56feaeacd2d4c858cf12370155c90d9", "b1fdb55628ef4deab74333414838e8a1", "c2cf32bbc86e47558b643061ee259207", "5a4e42f28fd541b2b93fa85476f23935", "19a4a603811e4c12a8b71e7ab57367ef", "dd8044969f52407ab3951771a536487d", "e3c3bf1e8ce64b61a317968f40b7b98c", "54880464a6974f5ba683407cf3d99faf", "50bfb0ff6a1044d78696e23f21326dc4", "18918393a100494c98772f5221404a1f", "c0a16708a5bb40a49704f603c870de5b", "a4463f0298f443d38d6f06cdd3f95f8c", "5922efe9dbff464393ed1e76685093bf", "e29e2fd443574c69af022dfef253d587", "e9f8a2f3e5434244afdd164f07e79a53", "698acc9f74574f6b9d8f22b3b41b7392", "436b45f7309944acb95fad6e5ce0bf13", "f9e3abef808446c388b2a8132121dcdb", "5259d666361641e09449aa903361124a", "301e40018b4c4262a1b9d3948e95afeb", "30140a0a2746420280dd754b9acde5e2", "34da297abe3c4e938686423e8825e1bf", "305f83d3aa624b909e76ddbd9425636c", "5aa6ad99fd754f259da4754502680baf", "d0e8087a02154e3c954d38cb2136810c", "06c4d7ca241e4a0e8fe67a631781c4d7", "ccdbd1df3467489094425677f51c0325", "e923319ff4ec4556b543f78d84f865f4", "97e0a8143f484c0d8d34a59535f818b0" ] }, "id": "tX7gYxZaVhYL", "outputId": "15f8a967-15ea-45fe-ecb9-91382c6fd9bb" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
/content/wandb/run-20231210_131434-p3xsnu5x
"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Step | \n", "Training Loss | \n", "
---|---|
50 | \n", "1.922300 | \n", "
100 | \n", "1.110100 | \n", "
150 | \n", "1.057100 | \n", "
200 | \n", "1.045300 | \n", "
250 | \n", "1.053800 | \n", "
300 | \n", "1.049200 | \n", "
350 | \n", "1.065100 | \n", "
400 | \n", "1.053200 | \n", "
450 | \n", "1.058000 | \n", "
500 | \n", "1.034800 | \n", "
550 | \n", "1.064300 | \n", "
600 | \n", "0.976800 | \n", "
650 | \n", "0.995700 | \n", "
700 | \n", "0.987800 | \n", "
750 | \n", "1.003000 | \n", "
800 | \n", "1.010500 | \n", "
"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/trl/trainer/utils.py:570: UserWarning: The dataset reached end and the iterator is reset to the start.\n",
" warnings.warn(\"The dataset reached end and the iterator is reset to the start.\")\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=800, training_loss=1.0929430866241454, metrics={'train_runtime': 2003.4806, 'train_samples_per_second': 0.799, 'train_steps_per_second': 0.399, 'total_flos': 1.65609039986688e+16, 'train_loss': 1.0929430866241454, 'epoch': 0.32})"
]
},
"metadata": {},
"execution_count": 30
}
],
"source": [
"trainer.train()"
],
"id": "14019fa9-0c6f-4729-ac99-0d407af375b8"
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "3Y4FQSyRghQt",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "9b8595ae-cb07-4961-8e08-0c8093adebeb"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'/gdrive/MyDrive/nlp/lora-midm-7b-nsmc-v1.2'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 31
}
],
"source": [
"script_args.training_args.output_dir"
],
"id": "3Y4FQSyRghQt"
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
"outputs": [],
"source": [
"trainer.save_model(script_args.training_args.output_dir)"
],
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
{
"cell_type": "markdown",
"metadata": {
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
"source": [
"# 추론 테스트"
],
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
"outputs": [],
"source": [
"from transformers import pipeline, TextStreamer"
],
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
"outputs": [],
"source": [
"instruction_prompt_template = \"\"\"###System;다음은 네이버 영화 리뷰들을 모아놓은 문장이다. 이를 분석하여 사용자가 작성한 영화 리뷰의 감정을 긍정 또는 부정으로 예측하라.\n",
"\n",
"### 리뷰 내용: {0} ### 분석 결과:\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"###System;{System}\n",
"###User;{User}\n",
"###Midm;\"\"\"\n",
"\n",
"default_system_msg = (\n",
" \"너는 먼저 사용자가 작성한 영화 리뷰의 감정을 분석하는 에이전트이다. 이로부터 주어진 영화 리뷰에 대한 긍정 또는 부정을 추출해야 한다.\"\n",
")"
],
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
"outputs": [],
"source": [
"def wrapper_generate(model, input_prompt, do_stream=False):\n",
" data = tokenizer(input_prompt, return_tensors=\"pt\")\n",
" streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
" input_ids = data.input_ids[..., :-1]\n",
" with torch.no_grad():\n",
" pred = model.generate(\n",
" #input_ids=input_ids.cuda(),\n",
" input_ids = input_ids.to('cuda'),\n",
" streamer=streamer if do_stream else None,\n",
" use_cache=True,\n",
" max_new_tokens=float('inf'),\n",
" do_sample=False\n",
" )\n",
" decoded_text = tokenizer.batch_decode(pred, skip_special_tokens=True)\n",
" decoded_text = decoded_text[0].replace(\"<[!newline]>\", \"\\n\")\n",
" return (decoded_text[len(input_prompt):])"
],
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
{
"cell_type": "code",
"source": [
"eval_dic = {i:wrapper_generate(model=base_model, input_prompt=prompt_template.format(System=default_system_msg, User=evaluation_queries[i]))for i, query in enumerate(evaluation_queries)}"
],
"metadata": {
"id": "aO5N5XX4BjFr"
},
"id": "aO5N5XX4BjFr",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"미세튜닝된 모델 테스트"
],
"metadata": {
"id": "DKeAPc9NAs4F"
},
"id": "DKeAPc9NAs4F"
},
{
"cell_type": "code",
"source": [
"from transformers import pipeline, TextStreamer"
],
"metadata": {
"id": "XMU2ydLG2k3Z"
},
"id": "XMU2ydLG2k3Z",
"execution_count": 38,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"프롬프트 정의"
],
"metadata": {
"id": "e-SkE9mf2sWx"
},
"id": "e-SkE9mf2sWx"
},
{
"cell_type": "code",
"source": [
"instruction_prompt_template = \"\"\"###System;다음은 네이버 영화 리뷰들을 모아놓은 문장이다. 이를 분석하여 사용자가 작성한 영화 리뷰의 감정을 긍정 또는 부정으로 예측하라.\n",
"\n",
"### 리뷰 내용: {0} ### 분석 결과:\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"###System;{System}\n",
"###User;{User}\n",
"###Midm;\"\"\"\n",
"\n",
"default_system_msg = (\n",
" \"너는 사용자가 작성한 리뷰의 긍정 또는 부정을 판단해야 한다.\"\n",
")"
],
"metadata": {
"id": "XUvbhk6t2pXs"
},
"id": "XUvbhk6t2pXs",
"execution_count": 39,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"사용자 영화 리뷰에 대한 입력 프롬프트 생성"
],
"metadata": {
"id": "p4J4p1b62yAM"
},
"id": "p4J4p1b62yAM"
},
{
"cell_type": "code",
"source": [
"def wrapper_generate(model, input_prompt, do_stream=False):\n",
" data = tokenizer(input_prompt, return_tensors=\"pt\")\n",
" streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
" input_ids = data.input_ids[..., :-1]\n",
" with torch.no_grad():\n",
" pred = model.generate(\n",
" input_ids=input_ids.cuda(),\n",
" streamer=streamer if do_stream else None,\n",
" use_cache=True,\n",
" max_new_tokens=float('inf'),\n",
" do_sample=False\n",
" )\n",
" decoded_text = tokenizer.batch_decode(pred, skip_special_tokens=True)\n",
" decoded_text = decoded_text[0].replace(\"<[!newline]>\", \"\\n\")\n",
" return (decoded_text[len(input_prompt):])"
],
"metadata": {
"id": "iov1rlPJ21K-"
},
"id": "iov1rlPJ21K-",
"execution_count": 40,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"양자화 설정: 4비트로 양자화함"
],
"metadata": {
"id": "Vs2CLd2D24ha"
},
"id": "Vs2CLd2D24ha"
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
"outputs": [],
"source": [
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
")"
],
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
{
"cell_type": "markdown",
"source": [
"미세튜닝된 모델을 4비트로 양자화함"
],
"metadata": {
"id": "q6n3Tr4L27kR"
},
"id": "q6n3Tr4L27kR"
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"id": "39db2ee4-23c8-471f-89b2-bca34964bf81",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 576
},
"outputId": "92593767-43aa-4670-9358-6b71c456365f"
},
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mtrained_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad_token_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpad_token_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mtrained_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbos_token_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbos_token_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'trained_model' is not defined"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\n",
" script_args.model_name,\n",
" trust_remote_code=True,\n",
" cache_dir=script_args.cache_dir,\n",
")\n",
"\n",
"if getattr(tokenizer, \"pad_token\", None) is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = \"right\" # Fix weird overflow issue with fp16 training\n",
"\n",
"tokenizer.add_special_tokens(dict(bos_token=''))\n",
"\n",
"trained_model.config.pad_token_id = tokenizer.pad_token_id\n",
"trained_model.config.bos_token_id = tokenizer.bos_token_id"
],
"id": "b0b75ca4-730d-4bde-88bb-a86462a76d52"
},
{
"cell_type": "markdown",
"source": [
"데이터셋 생성"
],
"metadata": {
"id": "SIHMdQmKADor"
},
"id": "SIHMdQmKADor"
},
{
"cell_type": "code",
"source": [
"valid_dataset = create_valid_datasets(tokenizer, script_args)"
],
"metadata": {
"id": "a_uxsZ4U3EHP"
},
"id": "a_uxsZ4U3EHP",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e374555b-9f8a-4617-8ea7-c1e6ee1b2999"
},
"outputs": [],
"source": [
"eval_dic = {i: wrapper_generate(model=trained_model, input_prompt=prompt_template.format(System=default_system_msg, User=example[\"document\"]))for i, example in enumerate(valid_dataset)}"
],
"id": "e374555b-9f8a-4617-8ea7-c1e6ee1b2999"
},
{
"cell_type": "markdown",
"source": [
"모델 결과"
],
"metadata": {
"id": "O50xByhv3RoQ"
},
"id": "O50xByhv3RoQ"
},
{
"cell_type": "code",
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, classification_report\n",
"\n",
"true_labels = [example[\"label\"] for example in valid_dataset]\n",
"\n",
"predicted_labels = [1 if \"긍정\" in eval_dic[i] else 0 for i in range(len(valid_dataset))]\n",
"\n",
"conf_matrix = confusion_matrix(true_labels, predicted_labels)\n",
"\n",
"accuracy = accuracy_score(true_labels, predicted_labels)\n",
"\n",
"precision = precision_score(true_labels, predicted_labels)\n",
"recall = recall_score(true_labels, predicted_labels)\n",
"f1 = f1_score(true_labels, predicted_labels)\n",
"\n",
"print(\"Precision:\", precision)\n",
"print(\"Recall:\", recall)\n",
"print(\"F1 Score:\", f1, \"\\n\")\n",
"\n",
"# 분류 리포트 출력\n",
"class_report = classification_report(true_labels, predicted_labels, target_names=['Negative', 'Positive'])\n",
"print(\"Classification Report:\\n\", class_report)"
],
"metadata": {
"id": "TL7zOjZD3T14"
},
"id": "TL7zOjZD3T14",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"![image.png]()"
],
"metadata": {
"id": "9MH1yeHHB6Ly"
},
"id": "9MH1yeHHB6Ly"
},
{
"cell_type": "markdown",
"source": [
"![image.png]()"
],
"metadata": {
"id": "xQd8rk9VBM3L"
},
"id": "xQd8rk9VBM3L"
},
{
"cell_type": "code",
"source": [
"!pip install --upgrade huggingface_hub\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "E5cdHlN0DBZ7",
"outputId": "d1dd4e38-84ab-4c56-fe0b-080edffb9ca6"
},
"id": "E5cdHlN0DBZ7",
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.19.4)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2023.6.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.66.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.5.0)\n",
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (23.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2023.11.17)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from huggingface_hub import Repository\n",
"import os\n",
"\n",
"notebook_file = \"okdol/hw-midm-7B-nsmc.ipynb\"\n",
"repo_name = \"okdol/hw-midm-7B-nsmc\"\n",
"repo_url = f\"https://huggingface.co/okdol/hw-midm-7B-nsmc\"\n",
"\n",
"\n",
"# 로컬에 리포지토리 디렉토리 생성\n",
"os.makedirs(repo_name, exist_ok=True)\n",
"\n",
"# 리포지토리 초기화\n",
"repo = Repository(local_dir=repo_name, clone_from=repo_url)\n",
"\n",
"# 노트북 파일을 리포지토리 디렉토리로 복사\n",
"notebook_path_in_repo = os.path.join(repo_name, os.path.basename(notebook_file))\n",
"os.replace(notebook_file, notebook_path_in_repo)\n",
"\n",
"# Hugging Face Hub에 푸시\n",
"repo.push_to_hub()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 388
},
"id": "UC-Ymv8pCaEr",
"outputId": "a6aaaa63-e5d4-45dd-e5cd-03d1fb031d11"
},
"id": "UC-Ymv8pCaEr",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:127: FutureWarning: 'Repository' (from 'huggingface_hub.repository') is deprecated and will be removed from version '1.0'. Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete removal is only planned on next major release.\n",
"For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.\n",
" warnings.warn(warning_message, FutureWarning)\n",
"/content/okdol/hw-midm-7B-nsmc is already a clone of https://huggingface.co/okdol/hw-midm-7B-nsmc. Make sure you pull the latest changes with `repo.git_pull()`.\n",
"WARNING:huggingface_hub.repository:/content/okdol/hw-midm-7B-nsmc is already a clone of https://huggingface.co/okdol/hw-midm-7B-nsmc. Make sure you pull the latest changes with `repo.git_pull()`.\n"
]
},
{
"output_type": "error",
"ename": "FileNotFoundError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.