{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "28e4c4d1-a73f-437b-a1bd-c2cc3874924a" }, "source": [ "# Lecture Week 11: midm-food-order-understanding\n", "\n", "1. Fine-tune KT-AI/midm-bitext-S-7B-inst-v1 to understand food-order sentences\n", "\n", "- food-order-understanding-small-3200.json (training)\n", "- food-order-understanding-small-800.json (validation)\n", "\n", "\n", "Prerequisites\n", "- Hugging Face account setup and approval to use llama-2\n", "- wandb for logging (logs are recorded there)" ], "id": "28e4c4d1-a73f-437b-a1bd-c2cc3874924a" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nDZe_wqKU6J3", "outputId": "184c46fb-9706-4e52-9193-a73c9e8eac50" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.7.0)\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.25.0)\n", "Requirement already satisfied: optimum in /usr/local/lib/python3.10/dist-packages (1.15.0)\n", "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.41.3)\n", "Requirement already satisfied: trl in /usr/local/lib/python3.10/dist-packages (0.7.4)\n", "Requirement already satisfied: wandb in /usr/local/lib/python3.10/dist-packages (0.16.1)\n", "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (0.7.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) 
(23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n", "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from optimum) (15.0.1)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from optimum) (2.15.0)\n", "Requirement already satisfied: tyro>=0.5.11 in /usr/local/lib/python3.10/dist-packages (from trl) (0.6.0)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.1.40)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.38.0)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages 
(from wandb) (1.3.3)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n", "Requirement already satisfied: sentencepiece!=0.1.92,>=0.1.91 in 
/usr/local/lib/python3.10/dist-packages (from transformers) (0.1.99)\n", "Requirement already satisfied: docstring-parser>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (0.15)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n", "Requirement already satisfied: shtab>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (1.6.5)\n", "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->optimum) (10.0)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.6)\n", "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.3.7)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (0.70.15)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (6.0.4)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n", 
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n" ] } ], "source": [ "pip install transformers peft accelerate optimum bitsandbytes trl wandb einops" ], "id": "nDZe_wqKU6J3" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d", "outputId": "480d3d43-f54e-45eb-b1d9-3659324d55f9" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be 
deprecated soon, please use `optimize_device_cache` instead.\n", " warnings.warn(\n" ] } ], "source": [ "import os\n", "from dataclasses import dataclass, field\n", "from typing import Optional\n", "import re\n", "\n", "import torch\n", "import tyro\n", "from accelerate import Accelerator\n", "from datasets import load_dataset, Dataset\n", "from peft import AutoPeftModelForCausalLM, LoraConfig\n", "from tqdm import tqdm\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", " TrainingArguments,\n", ")\n", "\n", "from trl import SFTTrainer\n", "\n", "from trl.trainer import ConstantLengthDataset" ], "id": "51eb00d7-2928-41ad-9ae9-7f0da7d64d6d" }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "39f60d5965554427af7d777ddfdb5c6e", "830abf25e3284b83b50259838de25461", "e3be9c1a41484139826db7c7e8bad684", "1b19e49a7fbb474c82f1608ec5cf01be", "47e6f26a412641d4bd7aee815075d817", "9c9ef3925e754ae2bd2948517c87ae8d", "2122301e2a8f481ebb22f718d5d38576", "84b344d92503448bafad168a2bcf70a3", "4225ce6045154131b73498b1b7b55d6d", "ce039ec952f14f35957b9b1f67242e87", "de82245c43a84ac992b7d99fcc360bdc", "1c379cd72c454368b2d7fa8f13d81c14", "f1960c78e0b9417a9d1dbcfac32be626", "14245abecc01427bb58bb10ae83df2e1", "ff827de069434301a4c7510d90d1ff5f", "6103bae35e6843ae93c45b4ac1433f3a", "105de6d562674a28a0817bdc4c933551", "155bef4539784268a616235b016d5309", "f50840f61e1e414094a9d7951e7d3ce8", "16821a4867544d84bcf301a95a56e83f", "868228753b36406daa2dee41b0dd3874", "452c00fbce3b425f940a72707a8792cc", "7dbdbc659e17416ab15171fe0379184f", "72654bb34fb04051bd9dcce3abf6188d", "353d5817b86c4d61b58956bb50da83b6", "dc992e598e364004b435d0e7dd01f7c2", "b4f83ff147854a7b81fd9510981adc36", "08bb43115500409b99b4710dab7c5237", "c8b962a661a641a1bc6b633f5ca1fbe8", "5a240d0c09664383975038ac835f4f86", "b6c7a0c22fa2426bb9d3b9adcc465e67", 
"cd2ce3e0aead44ca9ffd2c8d00f473e6" ] }, "id": "tX7gYxZaVhYL", "outputId": "5cff0d05-0481-48ab-fe2a-6b34b917d198" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
/content/wandb/run-20231211_091655-4b6dc5p8
"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Step | Training Loss\n",
"---|---\n",
"50 | 2.166900\n",
"100 | 1.059100\n",
"150 | 0.992300\n",
"200 | 0.981700\n",
"250 | 0.934600\n",
"300 | 0.889800\n"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=300, training_loss=1.1707455444335937, metrics={'train_runtime': 778.6233, 'train_samples_per_second': 0.771, 'train_steps_per_second': 0.385, 'total_flos': 1552584749875200.0, 'train_loss': 1.1707455444335937, 'epoch': 0.3})"
]
},
"metadata": {},
"execution_count": 26
}
],
"source": [
"trainer.train()  # requires signing up for a wandb account"
],
"id": "14019fa9-0c6f-4729-ac99-0d407af375b8"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "3Y4FQSyRghQt",
"outputId": "281bc4b2-077b-44ba-c27d-cc7017e2ffc6"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'/gdrive/MyDrive/lora-midm-7b-food-order-understanding'"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 27
}
],
"source": [
"script_args.training_args.output_dir"
],
"id": "3Y4FQSyRghQt"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
"outputs": [],
"source": [
"trainer.save_model(script_args.training_args.output_dir)"
],
"id": "49f05450-da2a-4edd-9db2-63836a0ec73a"
},
{
"cell_type": "markdown",
"metadata": {
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
"source": [
"# Inference test"
],
"id": "652f307e-e1d7-43ae-b083-dba2d94c2296"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
"outputs": [],
"source": [
"from transformers import pipeline, TextStreamer"
],
"id": "ea8a1fea-7499-4386-9dea-0509110f61af"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
"outputs": [],
"source": [
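"# Few-shot examples. Each block repeats the system instruction (\"You must judge the user's review as only 긍정/positive or 부정/negative\"),\n",
"# then gives a review sentence (리뷰 문장) and its label (분류 결과).\n",
"instruction_prompt_template = \"\"\"\n",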
"###System;너는 사용자의 리뷰를 긍정,부정 중 하나로만 판단해야 한다.\n",
"### 리뷰 문장: 진짜 재밌다 ### 분류 결과: 긍정\n",
"###System;너는 사용자의 리뷰를 긍정,부정 중 하나로만 판단해야 한다.\n",
"### 리뷰 문장: 나 잘 뻔 했잖아 영화보고 지루해서 ### 분류 결과: 부정\n",
"###System;너는 사용자의 리뷰를 긍정,부정 중 하나로만 판단해야 한다.\n",
"### 리뷰 문장: 어떻게 이렇게까지 재미없을 수가 있지 ### 분류 결과: 부정\n",
"###System;너는 사용자의 리뷰를 긍정,부정 중 하나로만 판단해야 한다.\n",
"### 리뷰 문장: 열린결말 영화 좋아하는데 이 영화가 열린결말이야 ### 분류 결과: 긍정\n",
"\n",
"\"\"\"\n",
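"# Inference prompt: system message, user input, then the model generates its answer after \"###Midm;\"\n",
"prompt_template = \"\"\"###System;{System}\n",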
"###User;{User}\n",
"###Midm;\"\"\"\n",
"\n",
"default_system_msg = (\n",
" \"너는 사용자의 리뷰를 긍정,부정 중 하나로만 판단해야 한다.\"\n",
")"
],
"id": "52626888-1f6e-46b6-a8dd-836622149ff5"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "46e844fa-8f63-4359-a4fb-df66e8171796"
},
"outputs": [],
"source": [
"# Korean movie reviews to classify as 긍정 (positive) or 부정 (negative)\n",
"evaluation_queries = [\n",
"    \"이게 재밌다는 사람들이 이해가 안가\",\n",
"    \"너무 흥미로워요 시즌2도 나왔으면 좋겠어요\",\n",
"    \"이게 무슨 영화야 지루하기짝이없네\",\n",
"    \"배울점이 많은 영화네요\",\n",
"]"
],
"id": "46e844fa-8f63-4359-a4fb-df66e8171796"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
"outputs": [],
"source": [
"def wrapper_generate(model, input_prompt, do_stream=False):\n",
"    data = tokenizer(input_prompt, return_tensors=\"pt\")\n",
"    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"    # Drop the last token of the encoded prompt before generation\n",
"    input_ids = data.input_ids[..., :-1]\n",
"    with torch.no_grad():\n",
"        pred = model.generate(\n",
"            input_ids=input_ids.cuda(),\n",
"            streamer=streamer if do_stream else None,\n",
"            use_cache=True,\n",
"            max_new_tokens=float('inf'),  # no explicit length cap; decoding relies on the model emitting EOS\n",
"            do_sample=False\n",
"        )\n",
"    decoded_text = tokenizer.batch_decode(pred, skip_special_tokens=True)\n",
"    # Replace the <[!newline]> marker in the decoded text with real newlines\n",
"    decoded_text = decoded_text[0].replace(\"<[!newline]>\", \"\\n\")\n",
"    # The decoded text still starts with the prompt, so return only what follows it\n",
"    return decoded_text[len(input_prompt):]"
],
"id": "1919cf1f-482e-4185-9d06-e3cea1918416"
},
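{
"cell_type": "markdown",
"metadata": {},
"source": [
"`wrapper_generate` removes the prompt by slicing the decoded string at `len(input_prompt)` characters, which assumes the decoded text reproduces the prompt character for character. A minimal alternative sketch slices at the token level instead, decoding only the ids generated after the prompt. It assumes the `tokenizer` loaded earlier in this notebook; the helper name `generate_new_text` and the `max_new_tokens=128` cap are illustrative choices, not part of the lecture code."
],
"id": "token-slice-sketch-md"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def generate_new_text(model, input_prompt, max_new_tokens=128):\n",
"    \"\"\"Hypothetical helper: return only the text generated after the prompt.\"\"\"\n",
"    data = tokenizer(input_prompt, return_tensors=\"pt\")\n",
"    # Same preprocessing as wrapper_generate: drop the last prompt token\n",
"    input_ids = data.input_ids[..., :-1].to(model.device)\n",
"    with torch.no_grad():\n",
"        pred = model.generate(\n",
"            input_ids=input_ids,\n",
"            max_new_tokens=max_new_tokens,  # finite cap (assumed value)\n",
"            do_sample=False,\n",
"        )\n",
"    # pred[0] holds the prompt ids followed by the new ids; keep only the new ones\n",
"    new_ids = pred[0, input_ids.shape[-1]:]\n",
"    text = tokenizer.decode(new_ids, skip_special_tokens=True)\n",
"    return text.replace(\"<[!newline]>\", \"\\n\")\n",
"\n",
"# Usage (assumes base_model and the prompt variables defined above):\n",
"# generate_new_text(base_model, prompt_template.format(System=default_system_msg, User=evaluation_queries[0]))"
],
"id": "token-slice-sketch-code"
},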
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eaac1f6f-c823-4488-8edb-2f931ddf0daa",
"outputId": "ecf6fcc2-79fb-48bd-facc-e07bf1106116"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1473: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n",
" warnings.warn(\n"
]
}
],
"source": [
"eval_dic = {\n",
"    i: wrapper_generate(\n",
"        model=base_model,\n",
"        input_prompt=prompt_template.format(System=default_system_msg, User=query),\n",
"    )\n",
"    for i, query in enumerate(evaluation_queries)\n",
"}"
],
"id": "eaac1f6f-c823-4488-8edb-2f931ddf0daa"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6",
"outputId": "e3b123dc-7e05-407e-9313-fe5fb3d6ab3a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"- 분석 결과 0: 음식명:치즈돈까스, 수량:한 판\n",
"- 분석 결과 1: 음식명:아메리카노, 옵션:아이스, 수량:한 잔\n"
]
}
],
"source": [
"print(eval_dic[0])"
],
"id": "fefd04ba-2ed8-4f84-bdd0-86d52b3f39f6"
},
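{
"cell_type": "markdown",
"metadata": {},
"source": [
"The system message asks for exactly one of two labels, 긍정 (positive) or 부정 (negative), but free-form generation may add extra words. A minimal sketch for mapping each answer in `eval_dic` to a label. The helper `extract_label` and its rule (take whichever label word appears first, otherwise `None`) are assumptions for illustration, not part of the lecture code."
],
"id": "label-extract-sketch-md"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def extract_label(generated_text):\n",
"    \"\"\"Hypothetical helper: return '긍정' or '부정', whichever occurs first, else None.\"\"\"\n",
"    positions = {\n",
"        label: generated_text.find(label)\n",
"        for label in (\"긍정\", \"부정\")\n",
"        if label in generated_text\n",
"    }\n",
"    if not positions:\n",
"        return None\n",
"    # The label with the smallest character offset appeared first\n",
"    return min(positions, key=positions.get)\n",
"\n",
"# Usage with the outputs collected in eval_dic above\n",
"for i, answer in eval_dic.items():\n",
"    print(i, evaluation_queries[i], \"->\", extract_label(answer))"
],
"id": "label-extract-sketch-code"
},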
{
"cell_type": "markdown",
"metadata": {
"id": "3f471e3a-723b-4df5-aa72-46f571f6bab6"
},
"source": [
"# Load and test the fine-tuned model"
],
"id": "3f471e3a-723b-4df5-aa72-46f571f6bab6"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
"outputs": [],
"source": [
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
")  # used to load the model checkpoint saved with save_model"
],
"id": "a43bdd07-7555-42b2-9888-a614afec892f"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 557
},
"id": "39db2ee4-23c8-471f-89b2-bca34964bf81",
"outputId": "36d043d6-c590-4d3f-adb0-6a67f2d5fc95"
},
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m