{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "215a1aae", "metadata": { "id": "215a1aae" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-04-23 21:39:14.489766: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-04-23 21:39:15.104927: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import pandas as pd\n", "\n", "from transformers import BertTokenizerFast, BertForSequenceClassification\n", "from transformers import Trainer, TrainingArguments" ] }, { "cell_type": "code", "execution_count": 3, "id": "J5Tlgp4tNd0U", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "J5Tlgp4tNd0U", "outputId": "f2eef2ee-7d9d-4f5b-e35c-e6015e68f59e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model_name = \"bert-base-uncased\"\n", "tokenizer = BertTokenizerFast.from_pretrained(model_name)\n", "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n", "model = model.to(\"cuda:0\")\n", "max_len = 200\n", "\n", "training_args = TrainingArguments(\n", " output_dir=\"results\",\n", " num_train_epochs=1,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=64,\n", " warmup_steps=500,\n", " learning_rate=5e-5,\n", " weight_decay=0.01,\n", " logging_dir=\"./logs\",\n", " logging_steps=10\n", " )\n", "\n", "# dataset class that inherits from torch.utils.data.Dataset\n", "\n", " \n", "class TokenizerDataset(Dataset):\n", " def __init__(self, strings):\n", " self.strings = strings\n", " \n", " def __getitem__(self, idx):\n", " return self.strings[idx]\n", " \n", " def __len__(self):\n", " return len(self.strings)\n", " " ] }, { "cell_type": "code", "execution_count": 4, "id": "9969c58c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9969c58c", "outputId": "5933b10b-9ddb-4b67-b66b-589207bef2d3", "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " id comment_text \\\n", "0 0000997932d777bf Explanation\\nWhy the edits made under my usern... \n", "1 000103f0d9cfb60f D'aww! He matches this background colour I'm s... \n", "2 000113f07ec002fd Hey man, I'm really not trying to edit war. It... \n", "3 0001b41b1c6bb37e \"\\nMore\\nI can't make any real suggestions on ... \n", "4 0001d958c54c6e35 You, sir, are my hero. Any chance you remember... \n", "... ... ... \n", "159566 ffe987279560d7ff \":::::And for the second time of asking, when ... \n", "159567 ffea4adeee384e90 You should be ashamed of yourself \\n\\nThat is ... \n", "159568 ffee36eab5c267c9 Spitzer \\n\\nUmm, theres no actual article for ... \n", "159569 fff125370e4aaaf3 And it looks like it was actually you who put ... \n", "159570 fff46fc426af1f9a \"\\nAnd ... I really don't think you understand... \n", "\n", " toxic severe_toxic obscene threat insult identity_hate \n", "0 0 0 0 0 0 0 \n", "1 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 \n", "... ... ... ... ... ... ... \n", "159566 0 0 0 0 0 0 \n", "159567 0 0 0 0 0 0 \n", "159568 0 0 0 0 0 0 \n", "159569 0 0 0 0 0 0 \n", "159570 0 0 0 0 0 0 \n", "\n", "[159571 rows x 8 columns]\n" ] } ], "source": [ "train_data = pd.read_csv(\"data/train.csv\")\n", "print(train_data)\n", "train_text = train_data[\"comment_text\"]\n", "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n", " \"obscene\", \"threat\", \n", " \"insult\", \"identity_hate\"]]\n", "\n", "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"]\n", "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n", " \"toxic\", \"severe_toxic\", \n", " \"obscene\", \"threat\", \n", " \"insult\", \"identity_hate\"]]\n", "\n", "# data preprocessing\n", "\n", "\n", "\n", "train_text = train_text.values.tolist()\n", "train_labels = train_labels.values.tolist()\n", "test_text = test_text.values.tolist()\n", "test_labels = test_labels.values.tolist()\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "1n56TME9Njde", "metadata": { "id": "1n56TME9Njde" }, "outputs": [], "source": [ "# prepare tokenizer and dataset\n", "\n", "class TweetDataset(Dataset):\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = labels\n", " self.tok = tokenizer\n", " \n", " def __getitem__(self, idx):\n", "# print(idx)\n", " print(len(self.labels))\n", " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n", " print(encoding.items())\n", " item = { key: torch.tensor(val) for key, val in encoding.items() }\n", " item['labels'] = torch.tensor(self.labels[idx])\n", "# print(item)\n", " return item\n", " \n", " def __len__(self):\n", " return len(self.labels)\n", "\n", "# no tokenizer\n", "class TweetDataset2(Dataset):\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = labels\n", " self.tok = tokenizer\n", " \n", " def __getitem__(self, idx):\n", "# print(idx)\n", " print(len(self.labels))\n", " encoding = self.tok(self.encodings.strings[idx], truncation=True, padding=\"max_length\", max_length=max_len).to(\"cuda:0\")\n", " print(encoding.items())\n", " item = { key: torch.tensor(val) for key, val in encoding.items() }\n", " item['labels'] = torch.tensor(self.labels[idx])\n", "# print(item)\n", " return item\n", " \n", " def __len__(self):\n", " return len(self.labels)\n", "\n", "\n", "\n", "\n", "train_strings = TokenizerDataset(train_text)\n", "test_strings = TokenizerDataset(test_text)\n", "\n", "train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)\n", "test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)\n", "\n", "\n", "\n", "\n", "train_encodings = tokenizer.batch_encode_plus(train_text, \\\n", " max_length=200, pad_to_max_length=True, \\\n", " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n", " ).to(\"cuda:0\")\n", "test_encodings = tokenizer.batch_encode_plus(test_text, \\\n", " max_length=200, pad_to_max_length=True, \\\n", " truncation=True, return_token_type_ids=False, return_tensors='pt' \\\n", " ).to(\"cuda:0\")\n", "\n", "# train_encodings = tokenizer(train_text, truncation=True, padding=True)\n", "# test_encodings = tokenizer(test_text, truncation=True, padding=True)" ] }, { "cell_type": "code", "execution_count": 15, "id": "4kwydz67qjW9", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4kwydz67qjW9", "outputId": "1653744e-69cf-46f8-a2d1-ffc3a3a4d58a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "159571\n", "159571\n" ] } ], "source": [ "# no tokenizer\n", "class TweetDataset3(Dataset):\n", " def __init__(self, encodings, labels):\n", " self.encodings = encodings\n", " self.labels = labels\n", " self.tok = tokenizer\n", " \n", " def __getitem__(self, idx):\n", " print(idx)\n", " item = { key: torch.tensor(val) for key, val in self.encodings.items() }\n", " item['labels'] = torch.tensor(self.labels[idx])\n", "# print(item)\n", " return item\n", " \n", " def __len__(self):\n", " return len(self.labels)\n", "\n", "\n", "\n", "train_dataset = TweetDataset3(train_encodings, train_labels)\n", "test_dataset = TweetDataset3(test_encodings, test_labels)\n", "\n", "print(len(train_dataset.labels))\n", "print(len(train_strings))\n", "\n", "\n", "class MultilabelTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False):\n", " labels = inputs.pop(\"labels\")\n", " outputs = model(**inputs)\n", " logits = outputs.logits\n", " loss_fct = torch.nn.BCEWithLogitsLoss()\n", " loss = loss_fct(logits.view(-1, self.model.config.num_labels), \n", " labels.float().view(-1, self.model.config.num_labels))\n", " return (loss, outputs) if return_outputs else loss\n", "\n", "\n", "# training\n", "trainer = MultilabelTrainer(\n", " model=model, \n", " args=training_args, \n", " train_dataset=train_dataset, \n", " eval_dataset=test_dataset\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "VwsyMZg_tgTg", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "VwsyMZg_tgTg", "outputId": "6cf8f3aa-629e-4650-9bbd-dfeb11071ef7" }, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }