{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "from datasets import load_dataset\n", "from transformers import DistilBertForSequenceClassification, DistilBertTokenizer\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.metrics import accuracy_score, f1_score\n", "\n", "import os\n", "import pickle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we are operating on a Mac with M2 chip, CUDA is not available. However, we can get GPU acceleration like this:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA? False\n", "MPS available? True\n", "MPS built? True\n", "Device: mps\n" ] } ], "source": [ "print(\"CUDA? \", torch.cuda.is_available())\n", "\n", "print(\"MPS available? \", torch.backends.mps.is_available()) #the MacOS is higher than 12.3+\n", "print(\"MPS built? \", torch.backends.mps.is_built()) #MPS is activated\n", "\n", "device = torch.device(\"mps\")\n", "print(\"Device: \", device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# EDA" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will load the dataset: https://huggingface.co/datasets/dair-ai/emotion" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No config specified, defaulting to: emotion/split\n", "Found cached dataset emotion (/Users/david/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "861120cd48de4646996c7bc8e7daf92a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textlabellabel_names
0i didnt feel humiliated0sadness
1i can go from feeling so hopeless to so damned...0sadness
2im grabbing a minute to post i feel greedy wrong3anger
3i am ever feeling nostalgic about the fireplac...2love
4i am feeling grouchy3anger
............
15995i just had a very brief time in the beanbag an...0sadness
15996i am now turning and i feel pathetic that i am...0sadness
15997i feel strong and good overall1joy
15998i feel like this was such a rude comment and i...3anger
15999i know a lot but i feel so stupid because i ca...0sadness
\n", "

16000 rows × 3 columns

\n", "" ], "text/plain": [ " text label label_names\n", "0 i didnt feel humiliated 0 sadness\n", "1 i can go from feeling so hopeless to so damned... 0 sadness\n", "2 im grabbing a minute to post i feel greedy wrong 3 anger\n", "3 i am ever feeling nostalgic about the fireplac... 2 love\n", "4 i am feeling grouchy 3 anger\n", "... ... ... ...\n", "15995 i just had a very brief time in the beanbag an... 0 sadness\n", "15996 i am now turning and i feel pathetic that i am... 0 sadness\n", "15997 i feel strong and good overall 1 joy\n", "15998 i feel like this was such a rude comment and i... 3 anger\n", "15999 i know a lot but i feel so stupid because i ca... 0 sadness\n", "\n", "[16000 rows x 3 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emotions.set_format(\"pandas\")\n", "df = emotions[\"train\"][:]\n", "\n", "label_dict = {\n", " id:label for id, label in enumerate(['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'])\n", " }\n", "\n", "df[\"label_names\"] = df[\"label\"].map(lambda x: label_dict[x])\n", "\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can check class imbalances (and resample if necessary) and also text lengths." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "joy 5362\n", "sadness 4666\n", "anger 2159\n", "fear 1937\n", "love 1304\n", "surprise 572\n", "Name: label_names, dtype: int64\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(df.label_names.value_counts())\n", "sns.histplot(df.label_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We might consider upsampling for the minority classes if the performance is hindered!" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df[\"word_count\"] = df.text.apply(lambda x: len(x.split()))\n", "\n", "df.boxplot(column=\"word_count\", by=\"label_names\", showfliers=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is no tweet longer than 70 words - this easily fits into the context size of our model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "emotions.reset_format()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading Pre-Trained Models\n", "\n", "https://huggingface.co/distilbert-base-uncased" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model_checkpoint = \"distilbert-base-uncased\"\n", "num_labels = 6\n", "\n", "tokenizer = DistilBertTokenizer.from_pretrained(model_checkpoint)\n", "model = DistilBertForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DistilBertTokenizer(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DistilBertForSequenceClassification(\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0-5): 6 x TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (pre_classifier): Linear(in_features=768, out_features=768, bias=True)\n", " (classifier): Linear(in_features=768, out_features=6, bias=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", ")" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We place the model on the GPU:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mps:0\n" ] } ], "source": [ "model.to(device)\n", "print(model.device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can tokenize and detokenize a single data point like this:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i didnt feel humiliated\n", "{'input_ids': [101, 1045, 2134, 2102, 2514, 26608, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}\n", "['[CLS]', 'i', 'didn', '##t', 'feel', 'humiliated', '[SEP]']\n", "[CLS] i didnt feel humiliated [SEP]\n" ] } ], "source": [ "ex_text_orig = train_ds[0]['text']\n", "print(ex_text_orig)\n", "\n", "ex_encoded_text = tokenizer(train_ds[0]['text'])\n", "print(ex_encoded_text)\n", "\n", "ex_tokens = tokenizer.convert_ids_to_tokens(ex_encoded_text.input_ids)\n", "print(ex_tokens)\n", "\n", "ex_string_recon = tokenizer.convert_tokens_to_string(ex_tokens)\n", "print(ex_string_recon)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running a forward pass through the model with a batch of tokenized data points:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 101, 1045, 2134, ..., 0, 0, 0],\n", " [ 101, 1045, 2064, ..., 0, 0, 0],\n", " [ 101, 10047, 9775, ..., 0, 0, 0],\n", " ...,\n", " [ 101, 1045, 8823, ..., 0, 0, 0],\n", " [ 101, 1045, 2514, ..., 0, 0, 0],\n", " [ 101, 1045, 2211, ..., 0, 0, 0]], device='mps:0'), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", " [1, 1, 1, ..., 0, 0, 0],\n", " [1, 1, 1, ..., 0, 0, 0],\n", " ...,\n", " [1, 1, 1, ..., 0, 0, 0],\n", " [1, 1, 1, ..., 0, 0, 0],\n", " [1, 1, 1, ..., 0, 0, 0]], device='mps:0')}\n", "SequenceClassifierOutput(loss=None, logits=tensor([[ 1.4233e-01, -3.5799e-02, -3.9052e-02, -1.2486e-01, 1.1337e-02,\n", " 4.1456e-03],\n", " [ 1.5962e-01, -1.8645e-03, -5.0531e-02, -1.1383e-01, 1.6687e-02,\n", " -1.2262e-02],\n", " [ 1.4096e-01, -2.8270e-02, -2.3139e-02, -1.4966e-01, -5.5413e-03,\n", " -3.0725e-04],\n", " [ 1.2116e-01, -2.1436e-02, -2.2654e-02, -1.1917e-01, 8.3146e-03,\n", " 1.2819e-02],\n", " [ 1.2686e-01, -2.5140e-03, -2.4941e-02, -1.6054e-01, 1.3207e-02,\n", " -2.0017e-03],\n", " [ 1.3553e-01, -4.4948e-02, -1.3296e-02, -1.1402e-01, -1.5227e-02,\n", " 3.8793e-02],\n", " [ 1.5724e-01, 6.9897e-02, -3.3859e-02, -1.2927e-01, -1.4081e-02,\n", " 2.3798e-02],\n", " [ 1.2619e-01, -5.0330e-03, -4.7603e-02, -1.2175e-01, -7.5088e-03,\n", " 1.0555e-02],\n", " [ 8.0060e-02, -3.2129e-02, -2.3077e-02, -1.6091e-01, 2.7138e-02,\n", " -7.4920e-03],\n", " [ 1.3315e-01, -4.5875e-02, -5.3141e-02, -1.4352e-01, -1.4877e-03,\n", " 1.2945e-02],\n", " [ 1.4902e-01, -1.4755e-02, -3.5898e-02, -1.3867e-01, -1.7973e-02,\n", " -1.0358e-02],\n", " [ 1.2576e-01, 6.0104e-04, -4.7409e-02, -1.2226e-01, 4.5271e-02,\n", " -1.5301e-02],\n", " [ 1.4248e-01, -2.0736e-02, -7.4252e-02, -1.3979e-01, 2.3742e-02,\n", " 7.6523e-03],\n", " [ 1.4122e-01, 1.4110e-02, -6.2735e-02, -1.4574e-01, -1.3873e-03,\n", " 2.6653e-03],\n", " [ 1.1123e-01, -7.4643e-03, -4.2686e-02, -1.1611e-01, 4.3666e-04,\n", " -9.3827e-03],\n", " [ 1.3886e-01, -1.5303e-03, -2.6063e-02, -1.4838e-01, 1.1344e-02,\n", " 1.5452e-02],\n", " [ 1.4855e-01, -5.7397e-02, -3.0159e-02, -1.3209e-01, 2.1942e-02,\n", " 9.1073e-03],\n", " [ 1.3703e-01, -2.3681e-02, -5.6169e-02, -1.3673e-01, 6.2913e-03,\n", " 1.8533e-02],\n", " [ 1.4937e-01, -1.9236e-02, -3.5975e-02, -9.6610e-02, -2.9679e-02,\n", " 8.2103e-02],\n", " [ 1.0979e-01, -1.7023e-02, -3.4062e-02, -1.4152e-01, -1.6392e-02,\n", " -2.3875e-02],\n", " [ 1.5651e-01, -1.8269e-03, -2.9920e-02, -1.3488e-01, -1.1655e-02,\n", " 3.7166e-03],\n", " [ 1.0090e-01, -1.8375e-03, -2.4937e-02, -1.2581e-01, 2.1587e-02,\n", " -2.9232e-05],\n", " [ 1.1796e-01, -3.7179e-02, -4.9161e-02, -1.3596e-01, 1.5601e-02,\n", " 1.0488e-02],\n", " [ 9.5583e-02, -3.4543e-02, -8.2448e-03, -1.1330e-01, -9.7575e-04,\n", " -1.9508e-02],\n", " [ 1.7488e-01, -6.2767e-03, -5.8071e-02, -1.0974e-01, 1.7336e-02,\n", " -1.2677e-02],\n", " [ 1.3011e-01, 7.0154e-03, -4.6719e-02, -1.2390e-01, 1.4996e-02,\n", " 5.8112e-03],\n", " [ 1.6171e-01, 2.8571e-02, 1.4520e-02, -1.1348e-01, 1.6719e-02,\n", " 4.6525e-02],\n", " [ 1.2096e-01, -4.7465e-03, -5.3050e-02, -1.2936e-01, 2.9222e-03,\n", " 4.4149e-03],\n", " [ 1.2663e-01, 2.1371e-02, -1.5378e-02, -1.4178e-01, -4.9094e-03,\n", " 2.5740e-02],\n", " [ 1.4949e-01, 4.5682e-02, -4.6581e-02, -1.1382e-01, 3.8925e-02,\n", " 1.8925e-02],\n", " [ 1.5930e-01, -2.6000e-02, -8.1725e-02, -1.2844e-01, 2.3078e-02,\n", " 1.7999e-02],\n", " [ 1.1508e-01, -5.4361e-02, -1.2097e-02, -1.1560e-01, -4.3136e-03,\n", " 1.6119e-02],\n", " [ 1.0650e-01, -2.9456e-02, -9.4272e-03, -1.4746e-01, -3.1861e-02,\n", " 1.9064e-02],\n", " [ 1.0366e-01, -9.7240e-03, -1.0911e-02, -1.2217e-01, -2.4717e-03,\n", " 1.3728e-02],\n", " [ 1.4861e-01, -4.1146e-02, -4.7642e-02, -1.2031e-01, 4.4020e-02,\n", " 1.0007e-03],\n", " [ 1.4028e-01, -5.5046e-02, -4.0656e-02, -1.2569e-01, -1.4845e-02,\n", " -9.3913e-03],\n", " [ 9.7403e-02, -1.8497e-02, -4.1954e-02, -1.5652e-01, -1.8027e-02,\n", " -1.0393e-02],\n", " [ 1.1734e-01, 2.1484e-03, -3.6896e-02, -1.3824e-01, 2.9927e-03,\n", " -9.7712e-03],\n", " [ 1.3530e-01, -4.1382e-03, -6.6913e-02, -1.2150e-01, 5.1750e-02,\n", " 5.2888e-03],\n", " [ 1.0634e-01, 1.9518e-02, -2.7008e-02, -1.1104e-01, 1.6907e-02,\n", " -2.3599e-02],\n", " [ 1.5376e-01, 7.4966e-03, -6.2959e-02, -1.0318e-01, -9.9983e-03,\n", " 3.7106e-02],\n", " [ 1.3008e-01, 2.0290e-02, -8.0737e-02, -1.4428e-01, 1.4979e-02,\n", " 2.6109e-03],\n", " [ 1.4071e-01, -1.4517e-02, -2.0958e-02, -1.1339e-01, 1.2747e-02,\n", " 9.3763e-03],\n", " [ 1.4183e-01, -9.8491e-03, -8.4350e-02, -1.4291e-01, 4.4465e-02,\n", " 9.4275e-04],\n", " [ 1.1117e-01, 2.8014e-02, -3.6789e-02, -1.2835e-01, 2.3733e-02,\n", " 2.4466e-02],\n", " [ 1.0633e-01, -3.3774e-02, -4.5758e-02, -1.3254e-01, -4.0638e-02,\n", " 3.3512e-02],\n", " [ 1.3389e-01, 1.6373e-02, -6.2875e-03, -1.2495e-01, 3.4706e-02,\n", " -1.7554e-02],\n", " [ 1.3963e-01, -2.7059e-02, -4.0488e-02, -1.0715e-01, -5.6545e-03,\n", " 2.7460e-03],\n", " [ 1.4890e-01, -4.0927e-02, -5.2144e-02, -7.0175e-02, 1.4726e-02,\n", " 4.2442e-02],\n", " [ 1.3726e-01, -3.9996e-02, -4.6873e-02, -1.4888e-01, 2.3211e-02,\n", " 3.4968e-02],\n", " [ 1.5906e-01, -4.4825e-02, -4.0935e-02, -1.3097e-01, 1.0965e-02,\n", " -1.5885e-02],\n", " [ 1.3486e-01, -4.9403e-02, -2.6408e-02, -1.2229e-01, 6.2970e-04,\n", " 3.2391e-02],\n", " [ 1.2863e-01, -5.5887e-03, -5.2600e-02, -1.3157e-01, -2.0005e-02,\n", " -1.3253e-02],\n", " [ 1.2188e-01, -6.8094e-02, -4.3069e-02, -1.0393e-01, -8.7743e-03,\n", " 7.4693e-03],\n", " [ 1.3216e-01, 3.3894e-02, -1.4873e-02, -7.0674e-02, 5.6290e-02,\n", " -8.4897e-03],\n", " [ 1.5045e-01, 4.7135e-03, -6.9442e-02, -1.3989e-01, 3.7274e-02,\n", " -2.3230e-02],\n", " [ 1.4205e-01, 1.5642e-02, -2.2888e-02, -1.3344e-01, 3.8608e-02,\n", " 6.5644e-03],\n", " [ 1.2683e-01, -2.0757e-02, -7.0946e-02, -1.3669e-01, 4.3046e-03,\n", " 1.5445e-02],\n", " [ 1.1804e-01, -6.8366e-03, -3.0480e-02, -1.3402e-01, 4.3425e-02,\n", " -3.6347e-03],\n", " [ 1.2283e-01, -4.2327e-02, -6.3867e-02, -1.2461e-01, -1.0048e-02,\n", " 2.0214e-02],\n", " [ 1.0351e-01, -3.1913e-03, -4.3201e-02, -1.5012e-01, 3.7961e-02,\n", " 1.6582e-02],\n", " [ 9.2689e-02, -2.0694e-02, 2.8925e-03, -1.4506e-01, -2.1509e-02,\n", " 2.4144e-02],\n", " [ 1.3375e-01, 6.9467e-03, -6.7183e-02, -1.3332e-01, -8.7235e-03,\n", " 3.4065e-02],\n", " [ 1.1857e-01, 4.5228e-02, -7.0513e-03, -1.3732e-01, -1.3650e-02,\n", " 4.4032e-02]], device='mps:0', grad_fn=), hidden_states=None, attentions=None)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/david/miniforge3/envs/hugging/lib/python3.9/site-packages/torch/_tensor_str.py:115: UserWarning: MPS: nonzero op is supported natively starting from macOS 13.0. Falling back on CPU. This may have performance implications. (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343673238/work/aten/src/ATen/native/mps/operations/Indexing.mm:218.)\n", " nonzero_finite_vals = torch.masked_select(\n" ] } ], "source": [ "ex_model_inputs = tokenizer(train_ds[:64]['text'], return_tensors=\"pt\", padding=True).to(device)\n", "ex_model_outputs = model(**ex_model_inputs)\n", "\n", "print(ex_model_inputs)\n", "print(ex_model_outputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Keep in mind that the output are logits, not probabilities!" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.0419, -0.0022, -0.0660, -0.0210, -0.0499, -0.0132, 0.0737, -0.0452,\n", " -0.1164, -0.0979, -0.0686, -0.0133, -0.0609, -0.0519, -0.0640, -0.0103,\n", " -0.0400, -0.0547, 0.0500, -0.1231, -0.0181, -0.0301, -0.0783, -0.0810,\n", " 0.0055, -0.0127, 0.1546, -0.0589, 0.0117, 0.0926, -0.0358, -0.0552,\n", " -0.0926, -0.0279, -0.0155, -0.1053, -0.1480, -0.0624, -0.0002, -0.0189,\n", " 0.0222, -0.0571, 0.0140, -0.0499, 0.0222, -0.1129, 0.0362, -0.0380,\n", " 0.0428, -0.0403, -0.0626, -0.0302, -0.0944, -0.0945, 0.1283, -0.0401,\n", " 0.0465, -0.0818, -0.0135, -0.0978, -0.0385, -0.0675, -0.0345, 0.0498],\n", " device='mps:0', grad_fn=)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.sum(ex_model_outputs.logits, dim=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Creating DataLoaders" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class TweetDataset(Dataset):\n", " def __init__(self, split):\n", " self.text = split[\"text\"]\n", " self.labels = split[\"label\"]\n", "\n", " def __len__(self):\n", " return len(self.text)\n", " \n", " def __getitem__(self, index):\n", " return self.text[index], self.labels[index]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 16\n", "\n", "train_dataset = TweetDataset(train_ds)\n", "val_dataset = TweetDataset(val_ds)\n", "test_dataset = TweetDataset(test_ds)\n", "\n", "train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(('i feel angered and firey',\n", " 'i feel a tender compassion glancing at her huge and heavy rucksack',\n", " 'ive been hanging around younger people and when i am with them i feel like im but when i see the photos of us together i am suddenly shaken to see just how old i look',\n", " 'i sometimes feel shy about my musical taste because some of it wanders towards what some might call techno slander',\n", " 'i woke up today feeling just as thankful',\n", " 'i feel a recipe is only a theme which an intelligent cook can play each time with a variation',\n", " 'i am feeling very anxious about going to therapy w',\n", " 'i no longer feel happy to score well',\n", " 'im sitting at the spare desk feeling totally disillusioned and frustrated with my working life in general',\n", " 'i am feeling peaceful yet determined as i listen to the slight humming noise of the ceiling fan',\n", " 'i see myself starting to feel the emotional dependence on my parents i stop and breathe',\n", " 'i don t mean this to be a serious recollection of feelings only a funny in a not funny sort of way story so let s get back to where the action begins',\n", " 'ive been feeling so anxious and nauseous and tired but also so elated that some nights its all i can do to crawl into bed',\n", " 'im feeling craving theres always a tender morsel of a song ready to appease my appetite',\n", " 'i am so sorry you are feeling so sad',\n", " 'i go up to her and i say feeling very impressed with myself youre naomi klein right'),\n", " tensor([3, 2, 4, 4, 1, 1, 4, 1, 0, 1, 0, 5, 4, 2, 0, 5]))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ex_batch_text, ex_batch_labels = next(iter(train_data_loader)) \n", "ex_batch_text, ex_batch_labels" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[3, 2, 4, 4, 1, 1, 4, 1, 0, 1, 0, 5, 4, 2, 0, 5]" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ex_batch_labels.tolist()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define Training Loop\n", "\n", "We could run a `Trainer` object from the `transformers` library but I want to implement the loop from scratch as practice. Furthermore, the logging of training metrics is implemented on a low level, instead of using `Tensorboard` and `SummaryWriter`." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=2e-5, weight_decay=0.01)\n", "\n", "def compute_metrics(logit_outputs, labels):\n", " preds = torch.max(F.softmax(logit_outputs, dim=-1), dim=-1)[1]\n", " acc = accuracy_score(labels, preds)\n", " f1 = f1_score(labels, preds, average='weighted')\n", " return acc, f1" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "EPOCHS = 3\n", "\n", "def train(\n", " model=model,\n", " train_data_loader=train_data_loader,\n", " val_data_loader=val_data_loader,\n", " tokenizer=tokenizer,\n", " optimizer=optimizer,\n", " loss_fn=loss_fn,\n", " epochs=EPOCHS):\n", " \n", " logs = {\n", " 'train_losses': [],\n", " 'train_metrics': [],\n", " 'val_losses': [],\n", " 'val_metrics': []\n", " }\n", "\n", " for epoch in range(EPOCHS):\n", " print(f\"--------------- EPOCH {epoch+1} ----------------\")\n", "\n", " # TRAIN\n", " train_loss_epoch = 0.0\n", " train_loss_iterations = []\n", " train_metrics_iterations = {'acc': [], 'f1': []}\n", " model.train()\n", " for i, train_batch in enumerate(train_data_loader):\n", " # algo \n", " optimizer.zero_grad()\n", " inputs, labels = train_batch\n", " inputs = tokenizer(inputs, return_tensors='pt', padding=True, truncation=True)\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", " outputs = model(**inputs)\n", " loss = loss_fn(outputs.logits, labels)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # logging\n", " train_loss_epoch += loss * inputs.input_ids.size()[0]\n", " train_loss_iterations.append(loss)\n", " acc, f1 = compute_metrics(logit_outputs=outputs.logits.cpu(), labels=labels.cpu())\n", " train_metrics_iterations['acc'].append(acc)\n", " train_metrics_iterations['f1'].append(f1)\n", " \n", " # free GPU memory\n", " inputs.input_ids.detach()\n", " inputs.attention_mask.detach()\n", " labels.detach()\n", "\n", " train_loss_epoch /= len(train_data_loader.dataset)\n", " train_acc_epoch = sum(train_metrics_iterations['acc']) / len(train_data_loader)\n", " train_f1_epoch = sum(train_metrics_iterations['f1']) / len(train_data_loader)\n", " print(f\"Average TRAIN loss: {train_loss_epoch}\")\n", " print(f\"Average TRAIN acc: {train_acc_epoch}\")\n", " print(f\"Average TRAIN F1: {train_f1_epoch}\")\n", " print()\n", " logs[\"train_losses\"].append(train_loss_iterations)\n", " logs[\"train_metrics\"].append(train_metrics_iterations)\n", "\n", "\n", " # VALIDATE\n", " val_loss_epoch = 0.0 \n", " val_loss_iterations = []\n", " val_metrics_iterations = {'acc': [], 'f1': []}\n", " model.eval()\n", " with torch.no_grad():\n", " for val_batch in val_data_loader:\n", " inputs, labels = val_batch\n", " inputs = tokenizer(inputs, return_tensors='pt', padding=True, truncation=True)\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", " outputs = model(**inputs)\n", " loss = loss_fn(outputs.logits, labels)\n", "\n", " val_loss_epoch += loss * inputs.input_ids.size()[0]\n", " val_loss_iterations.append(loss)\n", " acc, f1 = compute_metrics(logit_outputs=outputs.logits.cpu(), labels=labels.cpu())\n", " val_metrics_iterations['acc'].append(acc)\n", " val_metrics_iterations['f1'].append(f1)\n", "\n", " inputs.input_ids.detach()\n", " inputs.attention_mask.detach()\n", " labels.detach()\n", "\n", " val_loss_epoch /= len(val_data_loader.dataset)\n", " val_acc_epoch = sum(val_metrics_iterations['acc']) / len(val_data_loader)\n", " val_f1_epoch = sum(val_metrics_iterations['f1']) / len(val_data_loader)\n", " print(f\"Average VALIDATION loss: {val_loss_epoch}\")\n", " print(f\"Average VALIDATION acc: {val_acc_epoch}\")\n", " print(f\"Average VALIDATION F1: {val_f1_epoch}\")\n", " print()\n", " logs[\"val_losses\"].append(val_loss_iterations)\n", " logs[\"val_metrics\"].append(val_metrics_iterations)\n", " \n", " # CHECKPOINT\n", " torch.save(model, f\"checkpoints/model_epoch_{epoch+1}\")\n", " \n", " return logs" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------- EPOCH 1 ----------------\n", "Average TRAIN loss: 0.5427333116531372\n", "Average TRAIN acc: 0.810875\n", "Average TRAIN F1: 0.7868587813738663\n", "\n", "Average VALIDATION loss: 0.19164295494556427\n", "Average VALIDATION acc: 0.9295\n", "Average VALIDATION F1: 0.9345529768992618\n", "\n", "--------------- EPOCH 2 ----------------\n", "Average TRAIN loss: 0.17533010244369507\n", "Average TRAIN acc: 0.9329375\n", "Average TRAIN F1: 0.9323326799152223\n", "\n", "Average VALIDATION loss: 0.15689700841903687\n", "Average VALIDATION acc: 0.9355\n", "Average VALIDATION F1: 0.9365527271447162\n", "\n", "--------------- EPOCH 3 ----------------\n", "Average TRAIN loss: 0.16286173462867737\n", "Average TRAIN acc: 0.9343125\n", "Average TRAIN F1: 0.9334694767618423\n", "\n", "Average VALIDATION loss: 0.17313864827156067\n", "Average VALIDATION acc: 0.93\n", "Average VALIDATION F1: 0.9266658026613909\n", "\n" ] } ], "source": [ "history = train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It looks like we start to overfit within epoch 3. For further evaluation on the test set, we will load the checkpoint after epoch 2." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "losses_by_iteration = []\n", "for epoch in history[\"train_losses\"]:\n", " for iteration in epoch:\n", " losses_by_iteration.append(iteration.data.item())\n", "\n", "pd.Series(losses_by_iteration).plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation & Error Analysis" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "trained_model = torch.load(\"checkpoints/model_epoch_2\")" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='mps', index=0)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trained_model.device" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "def eval_on_test():\n", " model.eval()\n", " with torch.no_grad():\n", " test_loss = 0\n", " for batch in test_data_loader:\n", " inputs, labels = batch\n", " inputs = tokenizer(inputs, return_tensors='pt', padding=True, truncation=True)\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", " outputs = trained_model(**inputs)\n", " loss = loss_fn(outputs.logits, labels)\n", "\n", " test_loss += loss * inputs.input_ids.size()[0]\n", " acc, f1 = compute_metrics(outputs.logits.cpu(), labels.cpu())\n", "\n", " test_loss /= len(test_data_loader.dataset)\n", " print(\"Loss on TEST: \", test_loss)\n", " print(\"Acc on TEST: \", acc)\n", " print(\"F1 on TEST: \", f1)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss on TEST: tensor(0.1681, device='mps:0')\n", "Acc on TEST: 0.9375\n", "F1 on TEST: 0.9444444444444444\n" ] } ], "source": [ "eval_on_test()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**93,75%** accuracy on unseen data. Let's get a better understanding of the errors the model makes on the different categories - this way we can check if the class imbalance of the dataset had a negative impact." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Saving" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "torch.save(trained_model, \"../models/model\")\n", "torch.save(tokenizer, \"../models/tokenizer\")" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "with open(\"../models/label_dict\", 'wb') as file:\n", " pickle.dump(label_dict, file)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('ive been too deep down in the swamps swimming in muddy waters tortured by fear feeling lonely and lost',\n", " 'i am feeling a little apprehensive about the whole thing',\n", " 'i lift different now because it hurt so bad the day it happened that i can t get it out of my mind and i feel myself being a bit timid',\n", " 'im pretty sure and its been about a week and a half so although im feeling kind of betrayed and disillusioned by men at the moment everythings okay',\n", " 'i asked her what she meant by shes gonna feel jealous having loada of girls over me and then she said maybee i do like you a bitt',\n", " 'i was so uncomfortable and feeling weird feelings but wasn t sure if they were contractions since i never really felt contractions with jared until they jacked me up with pitocin',\n", " 'im not sure if im more at peace with our situation or if im just not feeling as bitter about it but in the past five months something has changed within me',\n", " 'i feel a bit calm now',\n", " 'i see all my friends posting pics and status updates of where they are going or what they are doing and i feel a bit jealous knowing it s not something i can get out and enjoy',\n", " 'i am feeling overwhelmed by trying to do it all that i think on the women before me',\n", " 'i dont know if i feel thrilled at finally getting to go camping again with people i like and know first time where thats happened',\n", " 'i feel like we may be coming to the point in the tv series where the show is incredibly popular but sadly the writers are coming to the end of their story lines and soon there will be nothing left to keep the plot a float',\n", " 'im sad if some people are unhappy about the flag for religious reasons but i know many religious people who do not feel it goes against their faith and they are very supportive',\n", " 'i carried my phone in my pocket and didn t feel the pull to get lost in it',\n", " 'i feel this strange sort of liberation',\n", " 'i feel inside this life is like a game sometimes then you came around me the walls just dissapeared nothing to surround me keep me from my fears im unprotected see how ive opened up youve made me trust coz ive never felt like this before im naked around you does it show'),\n", " tensor([0, 4, 4, 0, 3, 5, 3, 1, 3, 5, 1, 1, 2, 0, 5, 0])]" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(iter(test_data_loader)) " ] } ], "metadata": { "kernelspec": { "display_name": "hugging", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }