Upload Finetune_flan_t5_large_bnb_peft (1).ipynb
{"cells":[{"cell_type":"markdown","id":"lw1cWgq-DI5k","metadata":{"id":"lw1cWgq-DI5k"},"source":["# Fine-tune FLAN-T5 using `bitsandbytes`, `peft` & `transformers` 🤗"]},{"cell_type":"markdown","id":"kBFPA3-aDT7H","metadata":{"id":"kBFPA3-aDT7H"},"source":["In this notebook we will see how to properly use `peft`, `transformers` & `bitsandbytes` to fine-tune `flan-t5-large` in a Google Colab!\n","\n","We will fine-tune the model on the [`financial_phrasebank`](https://huggingface.co/datasets/financial_phrasebank) dataset, which consists of text-label pairs for classifying financial sentences as `positive`, `neutral` or `negative`.\n","\n","Note that you could use the same notebook to fine-tune `flan-t5-xl` as well, but you would need to shard the model first to avoid CPU RAM issues on Google Colab; check [these weights](https://huggingface.co/ybelkada/flan-t5-xl-sharded-bf16)."]},{"cell_type":"markdown","source":["## TODO #1\n","\n","Investigate what the `google/flan-t5-large` model was built to achieve and what capabilities can be expected of it.\n","- Write your answer in Markdown style."],"metadata":{"id":"5TXx1vj8kJSu"},"id":"5TXx1vj8kJSu"},{"cell_type":"markdown","source":["## Overview of the `google/flan-t5-large` model\n","\n","- The `google/flan-t5-large` model is based on the T5 architecture.\n","- T5 is a sequence-to-sequence model that takes text input and generates text output, and it is used to perform NLP tasks.\n","- The model follows the \"everything is text\" approach, handling input text and output text in the same format.\n","\n","## Expected capabilities and uses\n","\n","- The `google/flan-t5-large` model can perform a variety of NLP tasks, including:\n","  - Text generation: producing various kinds of text from an input text.\n","  - Summarization: condensing long documents or passages into concise summaries.\n","  - Translation: performing multilingual translation, rendering the input text in another language.\n","  - Question answering: understanding a question and its context to generate an answer.\n","  - Sentence classification: assigning a given sentence to a category or class.\n","\n","To automate and improve these NLP tasks with `google/flan-t5-large`, the model needs configuration and data appropriate to the specific capability and task; with these in place it can carry out natural language processing accurately and efficiently."],"metadata":{"id":"gNdrvxdIM83V"},"id":"gNdrvxdIM83V"},{"cell_type":"markdown","id":"ShAuuHCDDkvk","metadata":{"id":"ShAuuHCDDkvk"},"source":["## Install requirements"]},{"cell_type":"code","execution_count":null,"id":"DRQ4ZrJTDkSy","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DRQ4ZrJTDkSy","outputId":"3b98c09a-6889-4cdc-dddf-a7bb231b1f1d"},"outputs":[{"output_type":"stream","name":"stdout","text":["  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n","  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n","  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"]}],"source":["!pip install -q bitsandbytes datasets accelerate\n","!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main"]},{"cell_type":"markdown","id":"QBdCIrizDxFw","metadata":{"id":"QBdCIrizDxFw"},"source":["## Import model and tokenizer"]},{"cell_type":"code","execution_count":null,"id":"dd3c5acc","metadata":{"id":"dd3c5acc"},"outputs":[],"source":["# Select CUDA device index\n","import os\n","import torch\n","\n","os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n","\n","from datasets import load_dataset\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","model_name = \"google/flan-t5-large\"\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)\n","tokenizer = AutoTokenizer.from_pretrained(model_name)"]},{"cell_type":"markdown","id":"VwcHieQzD_dl","metadata":{"id":"VwcHieQzD_dl"},"source":["## Prepare model for training"]},{"cell_type":"markdown","id":"4o3ePxrjEDzv","metadata":{"id":"4o3ePxrjEDzv"},"source":["Some pre-processing needs to be done before training such an int8 model using `peft`, therefore let's import a utility function `prepare_model_for_int8_training` that will:\n","- Cast all the non-`int8` modules to full precision (`fp32`) for stability\n","- Add a `forward_hook` to the input
embedding layer to enable gradient computation of the input hidden states\n","- Enable gradient checkpointing for more memory-efficient training"]},{"cell_type":"code","execution_count":null,"id":"1629ebcb","metadata":{"id":"1629ebcb"},"outputs":[],"source":["from peft import prepare_model_for_int8_training\n","\n","model = prepare_model_for_int8_training(model)"]},{"cell_type":"markdown","id":"iCpAgawAEieu","metadata":{"id":"iCpAgawAEieu"},"source":["## Load your `PeftModel`\n","\n","Here we will use LoRA (Low-Rank Adapters) to train our model"]},{"cell_type":"code","execution_count":null,"id":"17566ae3","metadata":{"id":"17566ae3"},"outputs":[],"source":["from peft import LoraConfig, get_peft_model, TaskType\n","\n","\n","def print_trainable_parameters(model):\n","    \"\"\"\n","    Prints the number of trainable parameters in the model.\n","    \"\"\"\n","    trainable_params = 0\n","    all_param = 0\n","    for _, param in model.named_parameters():\n","        all_param += param.numel()\n","        if param.requires_grad:\n","            trainable_params += param.numel()\n","    print(\n","        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n","    )\n","\n","\n","lora_config = LoraConfig(\n","    r=16, lora_alpha=32, target_modules=[\"q\", \"v\"], lora_dropout=0.05, bias=\"none\", task_type=\"SEQ_2_SEQ_LM\"\n",")\n","\n","\n","model = get_peft_model(model, lora_config)\n","print_trainable_parameters(model)"]},{"cell_type":"markdown","id":"mGkwIgNXyS7U","metadata":{"id":"mGkwIgNXyS7U"},"source":["As you can see, here we are only training 0.6% of the parameters of the model!
This is a huge memory gain that will enable us to fine-tune the model without any memory issue."]},{"cell_type":"markdown","source":["## TODO #2\n","\n","Briefly investigate the mechanism by which the number of trainable parameters is reduced so drastically, to 0.6% as shown above.\n","- Write your answer in Markdown style."],"metadata":{"id":"9kkyrzsakn2b"},"id":"9kkyrzsakn2b"},{"cell_type":"markdown","source":["## Shrinking the trainable parameter count and improving memory efficiency\n","\n","The technique used in the code above shrinks the set of trainable model parameters while keeping the model trainable, which improves memory efficiency. This makes it practical to work with the model in memory-constrained environments.\n","\n","- **Printing the model parameters and memory impact**:\n","  - The `print_trainable_parameters` function prints the number of trainable parameters in the model.\n","  - In the output, \"trainable params\" is the number of parameters that are actually trainable.\n","  - \"all params\" is the total number of parameters in the model.\n","  - \"trainable%\" is the percentage of parameters that are trainable.\n","  - When \"trainable%\" is very low, only a small fraction of the model's parameters is kept trainable, and the memory required for training drops sharply.\n","\n","This approach is effective when fine-tuning a large model in a memory-constrained environment, since it lets you use memory efficiently and train and serve the model economically."],"metadata":{"id":"Yd8VN8RGNCmH"},"id":"Yd8VN8RGNCmH"},{"cell_type":"markdown","source":["## TODO #3\n","\n","If the `load_in_8bit=True` option is not used when loading the model, the original model is loaded.\n","\n","Compare the model structure in that case with the structure when `load_in_8bit=True` is used, and investigate the differences.\n","- Write your answer in Markdown style."],"metadata":{"id":"wgvqtHnFlNAl"},"id":"wgvqtHnFlNAl"},{"cell_type":"markdown","source":["## Comparing the `load_in_8bit=True` and `load_in_8bit=False` loading options\n","\n","Loading the model with `load_in_8bit=True` versus without it can change the model structure; the main differences are:\n","\n","1. **Data type of the model parameters**:\n","   - With `load_in_8bit=True`: model parameters are stored in 8-bit precision, i.e. the weights and biases are represented with fewer bits, which gives the model a much smaller memory footprint.\n","   - Without `load_in_8bit=True`: model parameters are typically stored in 32-bit or 16-bit precision, so the weights and biases take up more space and the memory requirements grow accordingly.\n","\n","2. **Memory requirements**:\n","   - With `load_in_8bit=True`: the amount of memory the model uses decreases, so it can run more efficiently.\n","   - Without `load_in_8bit=True`: the amount of memory the model uses can increase.\n","\n","3. **Performance and accuracy**:\n","   - With `load_in_8bit=True`: the 8-bit precision of the parameters may reduce the model's performance and accuracy, so prediction quality can degrade.\n","   - Without `load_in_8bit=True`: the parameters are loaded at their original precision, so the model's performance can be higher.\n","\n","In short, `load_in_8bit=True` improves memory efficiency but may reduce model quality, so it should be used with that trade-off in mind."],"metadata":{"id":"m08rbbKxPAby"},"id":"m08rbbKxPAby"},{"cell_type":"markdown","id":"HsG0x6Z7FwjZ","metadata":{"id":"HsG0x6Z7FwjZ"},"source":["## Load and process data\n","\n","Here we will use the [`financial_phrasebank`](https://huggingface.co/datasets/financial_phrasebank) dataset to fine-tune our model on sentiment classification of financial sentences.
We will load the split `sentences_allagree`, which, according to the dataset card, corresponds to the split where there is 100% annotator agreement."]},{"cell_type":"code","execution_count":null,"id":"242cdfae","metadata":{"id":"242cdfae"},"outputs":[],"source":["# loading dataset\n","dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n","dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n","dataset[\"validation\"] = dataset[\"test\"]\n","del dataset[\"test\"]\n","\n","classes = dataset[\"train\"].features[\"label\"].names\n","dataset = dataset.map(\n","    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n","    batched=True,\n","    num_proc=1,\n",")"]},{"cell_type":"markdown","id":"qzwyi-Z9yzRF","metadata":{"id":"qzwyi-Z9yzRF"},"source":["Let's also apply some pre-processing to the input data: the labels need to be pre-processed, and the tokens corresponding to `pad_token_id` need to be set to `-100` so that the `CrossEntropy` loss associated with the model correctly ignores these tokens."]},{"cell_type":"code","execution_count":null,"id":"6b7ea44c","metadata":{"id":"6b7ea44c"},"outputs":[],"source":["# data preprocessing\n","text_column = \"sentence\"\n","label_column = \"text_label\"\n","max_length = 128\n","\n","\n","def preprocess_function(examples):\n","    inputs = examples[text_column]\n","    targets = examples[label_column]\n","    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n","    labels = tokenizer(targets, max_length=3, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n","    labels = labels[\"input_ids\"]\n","    labels[labels == tokenizer.pad_token_id] = -100\n","    model_inputs[\"labels\"] = labels\n","    return model_inputs\n","\n","\n","processed_datasets = dataset.map(\n","    preprocess_function,\n","    batched=True,\n","    num_proc=1,\n","    remove_columns=dataset[\"train\"].column_names,\n","    load_from_cache_file=False,\n","
desc=\"Running tokenizer on dataset\",\n",")\n","\n","train_dataset = processed_datasets[\"train\"]\n","eval_dataset = processed_datasets[\"validation\"]"]},{"cell_type":"markdown","source":["## TODO #4\n","\n","Briefly investigate the structure of the Hub dataset `financial_phrasebank` used in the data loading/processing above, and how this cell is used for fine-tuning.\n","- Write your answer in Markdown style."],"metadata":{"id":"zmh21tjCm01z"},"id":"zmh21tjCm01z"},{"cell_type":"markdown","source":["## Fine-tuning an NLP model with the `financial_phrasebank` dataset\n","\n","A brief overview of the code:\n","\n","1. **Dataset loading and processing**:\n","   - `financial_phrasebank` is a dataset of finance-related text, loaded with the Hugging Face Datasets library.\n","   - The data is split into training and validation sets, and the labels are processed so they are ready for model training.\n","\n","2. **Model fine-tuning**:\n","   - Training settings are defined with `TrainingArguments`: the learning rate, batch size, number of epochs, and the save and evaluation schedule.\n","   - The `Trainer` class fine-tunes the model; it takes the model, the training settings, and the training and validation datasets.\n","   - Calling `trainer.train()` trains the model.\n","\n","3. **Model inference**:\n","   - The trained model is then used for evaluation and inference.\n","   - Calling `model.eval()` puts the model in inference mode, and an input sentence is defined.\n","   - The input sentence is tokenized and passed to the model to generate an output.\n","   - The model's output is decoded to obtain the prediction.\n","\n","4. **Printing the results**:\n","   - The input sentence and the model's prediction are printed.\n","\n","This code is a simple example of fine-tuning an NLP model on finance-related text so that it makes accurate predictions. Fine-tuning adapts the model to a specific dataset and task, which enables better performance and accuracy."],"metadata":{"id":"PzXUprxPPbI9"},"id":"PzXUprxPPbI9"},{"cell_type":"markdown","id":"bcNTdVypGEPb","metadata":{"id":"bcNTdVypGEPb"},"source":["## Train our model!\n","\n","Let's now train our model by running the cells below.\n","Note that for T5, since some layers are kept in `float32` for stability purposes, there is no need to call autocast on the trainer."]},{"cell_type":"code","execution_count":null,"id":"69c756ac","metadata":{"id":"69c756ac"},"outputs":[],"source":["from transformers import TrainingArguments, Trainer\n","\n","training_args = TrainingArguments(\n","    \"temp\",\n","    evaluation_strategy=\"epoch\",\n","    learning_rate=1e-3,\n","    gradient_accumulation_steps=1,\n","    auto_find_batch_size=True,\n","    num_train_epochs=1,\n","    save_steps=100,\n","    save_total_limit=8,\n",")\n","trainer = Trainer(\n","    model=model,\n","    args=training_args,\n","    train_dataset=train_dataset,\n","    eval_dataset=eval_dataset,\n",")\n","model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"]},{"cell_type":"code","execution_count":null,"id":"ab52b651","metadata":{"id":"ab52b651"},"outputs":[],"source":["trainer.train()"]},{"cell_type":"markdown","id":"r98VtofiGXtO","metadata":{"id":"r98VtofiGXtO"},"source":["## Qualitatively test our model"]},{"cell_type":"markdown","id":"NIm7z3UNzGPP","metadata":{"id":"NIm7z3UNzGPP"},"source":["Let's have a quick qualitative evaluation of the model by taking a sample from the dataset that corresponds to a positive label.
Run your generation just as you would when running the model from `transformers`:"]},{"cell_type":"code","execution_count":null,"id":"c95d6173","metadata":{"id":"c95d6173"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]},{"cell_type":"markdown","source":["## TODO #5\n","\n","Create your own Hugging Face account and carry out the upload/download/verification steps below against your own account.\n","\n","Afterwards, write down the model id of the Hugging Face Hub repository you uploaded.\n","- Write your answer in Markdown style."],"metadata":{"id":"ubwn2Qdbl3Fb"},"id":"ubwn2Qdbl3Fb"},{"cell_type":"markdown","source":[],"metadata":{"id":"hK-Mdl4VgKcN"},"id":"hK-Mdl4VgKcN"},{"cell_type":"markdown","id":"9QqBlwzoGZ3f","metadata":{"id":"9QqBlwzoGZ3f"},"source":["## Share your adapters on 🤗 Hub"]},{"cell_type":"markdown","id":"NT-C8SjcKqUx","metadata":{"id":"NT-C8SjcKqUx"},"source":["Once you have trained your adapter, you can easily share it on the Hub using the `push_to_hub` method. Note that only the adapter weights and config will be pushed."]},{"cell_type":"code","execution_count":null,"id":"bcbfa1f9","metadata":{"id":"bcbfa1f9"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},{"cell_type":"code","execution_count":null,"id":"rFKJ4vHNGkJw","metadata":{"id":"rFKJ4vHNGkJw"},"outputs":[],"source":["model.push_to_hub(\"yysspp/flan-t5-large-financial-phrasebank-lora\", use_auth_token=True)"]},{"cell_type":"markdown","id":"xHuDmbCYJ89f","metadata":{"id":"xHuDmbCYJ89f"},"source":["## Load your adapter from the Hub"]},{"cell_type":"markdown","id":"ANFo6DdfKlU3","metadata":{"id":"ANFo6DdfKlU3"},"source":["You can load the model together with the adapter in a few lines of code!
Check the snippet below to load the adapter from the Hub and run the example evaluation!"]},{"cell_type":"code","execution_count":null,"id":"j097aaPWJ-9u","metadata":{"id":"j097aaPWJ-9u"},"outputs":[],"source":["import torch\n","from peft import PeftModel, PeftConfig\n","from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n","\n","peft_model_id = \"yysspp/flan-t5-large-financial-phrasebank-lora\"\n","config = PeftConfig.from_pretrained(peft_model_id)\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, torch_dtype=\"auto\", device_map=\"auto\")\n","tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n","\n","# Load the Lora model\n","model = PeftModel.from_pretrained(model, peft_model_id)"]},{"cell_type":"code","execution_count":null,"id":"jmjwWYt0KI_I","metadata":{"id":"jmjwWYt0KI_I"},"outputs":[],"source":["model.eval()\n","input_text = \"In January-September 2009 , the Group 's net interest income increased to EUR 112.4 mn from EUR 74.3 mn in January-September 2008 .\"\n","inputs = tokenizer(input_text, return_tensors=\"pt\")\n","\n","outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n","\n","print(\"input sentence: \", input_text)\n","print(\" output prediction: \", tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"]}],"metadata":{"accelerator":"GPU","colab":{"provenance":[],"gpuType":"T4","toc_visible":true},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"},"vscode":{"interpreter":{"hash":"1219a10c7def3e2ad4f431cfa6f49d569fcc5949850132f23800e792129eefbb"}}},"nbformat":4,"nbformat_minor":5}