{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning for multi-label text classification\n", "**Note: This notebook was run in Google Colab**\n", "\n", "This notebook demonstrates how to fine-tune a `bert-base-uncased` model using this Kaggle [dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge)\n", "\n", "The Colab link is here: https://colab.research.google.com/drive/1_tOvmArkigdQpxhZhzVIhR58InDHrxPz\n", "\n", "## Setup Environment\n", "We first install and import all the necessary libraries and modules." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wybe5jQM1NLf", "outputId": "2f86abba-83b8-4d89-dcf8-3dec710886e0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting transformers\n", "Installing collected packages: tokenizers, huggingface-hub, transformers\n",
"Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1\n" ] } ], "source": [ "!pip install transformers" ] },
{ "cell_type": "markdown", "metadata": { "id": "AYvuPa35Wq9C" }, "source": [ "---------------------------------------------------------" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hQN-HmXXW6SA" }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "from torch.utils.data import Dataset\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Load Dataset\n", "Here we load `train.csv`, draw a random sample of 25,000 comments, and split it into training and validation sets (80% / 20%)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "WtsAFyrzWuCr" }, "outputs": [], "source": [ "# Read the dataset and extract all texts and labels\n", "df = pd.read_csv(\"/content/drive/MyDrive/AI_project/data/train.csv\")\n", "\n", "train_texts = df[\"comment_text\"].values\n", "labels = df.columns[2:]\n", "id2label = {idx: label for idx, label in enumerate(labels)}\n", "label2id = {label: idx for idx, label in enumerate(labels)}\n", "train_labels = df[labels].values\n", "\n", "# Randomly sample 25,000 comments; sampling indices once keeps texts and labels aligned\n", "np.random.seed(18)\n", "sample_idx = np.random.choice(train_labels.shape[0], size=25000, replace=False)\n", "small_train_texts = train_texts[sample_idx]\n", "small_train_labels = train_labels[sample_idx, :]\n", "\n", "# Split the sample into 80% training data and 20% validation data\n", "train_texts, val_texts, train_labels, val_labels = train_test_split(small_train_texts, small_train_labels, test_size=.2)" ] },
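{ "cell_type": "markdown", "metadata": {}, "source": [ "Each comment can carry several of the six Jigsaw toxicity labels (`toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, `identity_hate`) at the same time, so every example is paired with a multi-hot label vector rather than a single class id. The quick check below is purely illustrative and simply prints the label mapping together with one example." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sanity check: show the label mapping and one multi-hot label vector\n", "print(id2label)\n", "print(train_texts[0][:80])\n", "print(train_labels[0])  # e.g. [1 0 0 0 1 0] would mean 'toxic' and 'insult'" ] },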
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "Models like BERT do not take raw text as input; they expect token ids (`input_ids`), an attention mask, and related tensors. We therefore tokenize the comments with the model's tokenizer. `AutoTokenizer` automatically loads the tokenizer that matches the `bert-base-uncased` checkpoint on the Hugging Face Hub. We then wrap the texts and labels in a custom PyTorch `Dataset` class defined below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "pPgvgOaYXb2f" }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")" ] },
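{ "cell_type": "markdown", "metadata": {}, "source": [ "To see what the tokenizer actually produces, the illustrative cell below encodes a single comment and prints the returned keys and the first few token ids; it is a quick check only and is not needed for training." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check of the tokenizer output (not needed for training)\n", "sample_encoding = tokenizer(train_texts[0], truncation=True, padding=\"max_length\")\n", "print(sample_encoding.keys())             # typically input_ids, token_type_ids, attention_mask\n", "print(sample_encoding[\"input_ids\"][:10])  # first ten ids of the padded sequence" ] },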
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "aysAKCYoXBoz" }, "outputs": [], "source": [ "# Define a Dataset class that tokenizes each comment on the fly\n", "class TextDataset(Dataset):\n", "    def __init__(self, texts, labels):\n", "        self.texts = texts\n", "        self.labels = labels\n", "\n", "    def __getitem__(self, idx):\n", "        encodings = tokenizer(self.texts[idx], truncation=True, padding=\"max_length\")\n", "        item = {key: torch.tensor(val) for key, val in encodings.items()}\n", "        # Multi-label targets must be float tensors for the BCE-with-logits loss\n", "        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float32)\n", "        return item\n", "\n", "    def __len__(self):\n", "        return len(self.labels)\n", "\n", "train_dataset = TextDataset(train_texts, train_labels)\n", "val_dataset = TextDataset(val_texts, val_labels)" ] },
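{ "cell_type": "markdown", "metadata": {}, "source": [ "With the datasets ready, the model and `Trainer` can be configured. The cell below is a minimal sketch: the five training epochs and the per-epoch evaluation match the run summarized in the table that follows, while the learning rate, batch sizes, and output directory are assumed values. Setting `problem_type=\"multi_label_classification\"` makes the model apply a sigmoid and a binary cross-entropy loss over all six labels instead of a softmax over classes." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of the model and Trainer setup for multi-label classification.\n", "# The learning rate, batch sizes, and output directory are assumed values,\n", "# not settings recovered from the original run.\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", "    \"bert-base-uncased\",\n", "    problem_type=\"multi_label_classification\",  # sigmoid + BCE-with-logits loss\n", "    num_labels=len(labels),\n", "    id2label=id2label,\n", "    label2id=label2id,\n", ")\n", "\n", "training_args = TrainingArguments(\n", "    output_dir=\"bert-toxic-comment-multilabel\",  # assumed name\n", "    num_train_epochs=5,\n", "    per_device_train_batch_size=16,\n", "    per_device_eval_batch_size=16,\n", "    learning_rate=2e-5,\n", "    evaluation_strategy=\"epoch\",  # report validation loss once per epoch\n", "    save_strategy=\"epoch\",\n", "    load_best_model_at_end=True,\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_dataset,\n", "    eval_dataset=val_dataset,\n", ")\n", "\n", "trainer.train()" ] },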
{ "cell_type": "markdown", "metadata": {}, "source": [ "Training for five epochs produced the following losses:\n", "\n", "| Epoch | Training Loss | Validation Loss |\n", "|---|---|---|\n", "| 1 | 0.052500 | 0.048165 |\n", "| 2 | 0.037000 | 0.044507 |\n", "| 3 | 0.027500 | 0.048948 |\n", "| 4 | 0.018800 | 0.049080 |\n", "| 5 | 0.014600 | 0.050677 |" ] },
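{ "cell_type": "markdown", "metadata": {}, "source": [ "After fine-tuning, predictions are made per label: the model returns one logit per toxicity category, each logit is passed through a sigmoid, and each probability is thresholded independently. The cell below is an illustrative sketch; the 0.5 threshold and the example sentence are assumed, not taken from the original run." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative inference sketch: sigmoid each logit, then threshold per label.\n", "# The 0.5 threshold and the example text are assumed values.\n", "text = \"You are a wonderful person.\"\n", "inputs = tokenizer(text, truncation=True, return_tensors=\"pt\").to(model.device)\n", "\n", "with torch.no_grad():\n", "    logits = model(**inputs).logits\n", "\n", "probs = torch.sigmoid(logits)[0]\n", "print({id2label[i]: round(p.item(), 3) for i, p in enumerate(probs)})\n", "print(\"Predicted labels:\", [id2label[i] for i, p in enumerate(probs) if p > 0.5])" ] },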