{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "BERT-explainability.ipynb", "provenance": [], "authorship_tag": "ABX9TyOm8dIRrumd5XNcc+fntVA5", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YCdGaMuy56TA", "outputId": "8f802262-55eb-4366-b772-89c4756224b3" }, "source": [
  "# Fetch the Transformer-Explainability repo and run the rest of the notebook from inside it.\n",
  "# If the folder already exists, git only prints an error and the existing copy is used.\n",
  "!git clone https://github.com/hila-chefer/Transformer-Explainability.git\n",
  "\n",
  "import os\n",
  "os.chdir('./Transformer-Explainability')\n",
  "\n",
  "# Install the repo's pinned requirements together with a pinned captum in ONE pip call,\n",
  "# so a single resolver run sees both requirements.txt's matplotlib==3.3.2 pin and captum's\n",
  "# own matplotlib requirement. %pip (unlike !pip) installs into the running kernel's environment.\n",
  "%pip install -r requirements.txt captum==0.6.0"
 ], "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "fatal: destination path 'Transformer-Explainability' already exists and is not an empty directory.\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: Pillow>=8.1.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 1)) (9.4.0)\n", "Requirement already satisfied: einops==0.3.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 2)) (0.3.0)\n", "Requirement already satisfied: h5py==2.8.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 3)) (2.8.0)\n", "Requirement already satisfied: imageio==2.9.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 4)) (2.9.0)\n", "Collecting matplotlib==3.3.2\n", " Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n", "Requirement already satisfied: opencv_python in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 6)) (4.6.0.66)\n", "Requirement already satisfied: scikit_image==0.17.2 in /usr/local/lib/python3.8/dist-packages (from 
-r requirements.txt (line 7)) (0.17.2)\n", "Requirement already satisfied: scipy==1.5.2 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 8)) (1.5.2)\n", "Requirement already satisfied: sklearn in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 9)) (0.0.post1)\n", "Requirement already satisfied: torch==1.7.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 10)) (1.7.0)\n", "Requirement already satisfied: torchvision==0.8.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 11)) (0.8.1)\n", "Requirement already satisfied: tqdm==4.51.0 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 12)) (4.51.0)\n", "Requirement already satisfied: transformers==3.5.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 13)) (3.5.1)\n", "Requirement already satisfied: utils==1.0.1 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 14)) (1.0.1)\n", "Requirement already satisfied: Pygments>=2.7.4 in /usr/local/lib/python3.8/dist-packages (from -r requirements.txt (line 15)) (2.14.0)\n", "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.21.6)\n", "Requirement already satisfied: six in /usr/local/lib/python3.8/dist-packages (from h5py==2.8.0->-r requirements.txt (line 3)) (1.15.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (1.4.4)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (3.0.9)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2.8.2)\n", "Requirement already satisfied: cycler>=0.10 in 
/usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (0.11.0)\n", "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2->-r requirements.txt (line 5)) (2022.12.7)\n", "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (3.0)\n", "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (2022.10.10)\n", "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from scikit_image==0.17.2->-r requirements.txt (line 7)) (1.4.1)\n", "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.6)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (4.4.0)\n", "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch==1.7.0->-r requirements.txt (line 10)) (0.16.0)\n", "Requirement already satisfied: sacremoses in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.0.53)\n", "Requirement already satisfied: protobuf in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.19.6)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (3.9.0)\n", "Requirement already satisfied: sentencepiece==0.1.91 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.1.91)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (21.3)\n", 
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2022.6.2)\n", "Requirement already satisfied: tokenizers==0.9.3 in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (0.9.3)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers==3.5.1->-r requirements.txt (line 13)) (2.25.1)\n", "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (4.0.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (1.24.3)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers==3.5.1->-r requirements.txt (line 13)) (2.10)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (1.2.0)\n", "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from sacremoses->transformers==3.5.1->-r requirements.txt (line 13)) (7.1.2)\n", "Installing collected packages: matplotlib\n", " Attempting uninstall: matplotlib\n", " Found existing installation: matplotlib 3.6.3\n", " Uninstalling matplotlib-3.6.3:\n", " Successfully uninstalled matplotlib-3.6.3\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed matplotlib-3.3.2\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: captum in /usr/local/lib/python3.8/dist-packages (0.6.0)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum) (3.3.2)\n", "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum) (1.7.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum) (1.21.6)\n", "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.16.0)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (4.4.0)\n", "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum) (0.6)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (0.11.0)\n", "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (9.4.0)\n", "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2022.12.7)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (1.4.4)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (3.0.9)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum) (2.8.2)\n", "Requirement already satisfied: six>=1.5 in 
/usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->captum) (1.15.0)\n" ] } ] }, { "cell_type": "code", "source": [
  "# Re-assert the captum / matplotlib versions this notebook was run with, in a single pip call.\n",
  "# In the saved run, matplotlib 3.6.3 had replaced the 3.3.2 pinned in requirements.txt.\n",
  "# %pip (unlike !pip) installs into the running kernel's environment.\n",
  "%pip install captum==0.6.0 matplotlib==3.3.2"
 ], "metadata": { "id": "zDPnh4lofcNw", "outputId": "3d585bbc-ff3b-4a09-b5bf-57bb4d46e830", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: captum==0.6.0 in /usr/local/lib/python3.8/dist-packages (0.6.0)\n", "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.7.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (1.21.6)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum==0.6.0) (3.6.3)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (4.4.0)\n", "Requirement already satisfied: future in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.16.0)\n", "Requirement already satisfied: dataclasses in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->captum==0.6.0) (0.6)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.4.4)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (1.0.7)\n", "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (9.4.0)\n", "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (2.8.2)\n", "Requirement already satisfied: packaging>=20.0 in 
/usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (21.3)\n", "Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (3.0.9)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (4.38.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum==0.6.0) (0.11.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7->matplotlib->captum==0.6.0) (1.15.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting matplotlib==3.3.2\n", " Using cached matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl (11.6 MB)\n", "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (9.4.0)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (0.11.0)\n", "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.21.6)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (3.0.9)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (1.4.4)\n", "Requirement already satisfied: certifi>=2020.06.20 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2022.12.7)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib==3.3.2) (2.8.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib==3.3.2) (1.15.0)\n", "Installing collected packages: matplotlib\n", " Attempting uninstall: 
matplotlib\n", " Found existing installation: matplotlib 3.6.3\n", " Uninstalling matplotlib-3.6.3:\n", " Successfully uninstalled matplotlib-3.6.3\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "fastai 2.7.10 requires torchvision>=0.8.2, but you have torchvision 0.8.1 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed matplotlib-3.3.2\n" ] } ] },
 { "cell_type": "code", "metadata": { "id": "4-XGl_Zw6Aht" }, "source": [
  "import torch\n",
  "from captum.attr import visualization\n",
  "from transformers import AutoTokenizer\n",
  "\n",
  "# LRP-enabled BERT classifier and explanation generator from the cloned repo.\n",
  "from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification\n",
  "from BERT_explainability.modules.BERT.ExplanationGenerator import Generator"
 ], "execution_count": 10, "outputs": [] },
 { "cell_type": "code", "metadata": { "id": "VakYjrkC6C3S" }, "source": [
  "MODEL_NAME = \"textattack/bert-base-uncased-SST-2\"\n",
  "\n",
  "model = BertForSequenceClassification.from_pretrained(MODEL_NAME).to(\"cuda\")\n",
  "model.eval()\n",
  "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
  "# initialize the explanations generator\n",
  "explanations = Generator(model)\n",
  "\n",
  "# Label names indexed by class id: 0 = NEGATIVE, 1 = POSITIVE.\n",
  "classifications = [\"NEGATIVE\", \"POSITIVE\"]\n"
 ], "execution_count": 11, "outputs": [] },
 { "cell_type": "markdown", "metadata": { "id": "jGRp376FPOvV" }, "source": [ "# Positive sentiment example" ] },
 { "cell_type": "code", "metadata": { "id": "uSLZtv546H2z", "colab": { "base_uri": "https://localhost:8080/", "height": 219 }, "outputId": "26712e90-0b77-40b0-a908-fef13dd88bcd" }, "source": [
  "# encode a sentence\n",
  "text_batch = [\"This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.\"]\n",
  "encoding = tokenizer(text_batch, return_tensors='pt')\n",
  "input_ids = encoding['input_ids'].to(\"cuda\")\n",
  "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
  "\n",
  "# true class is positive - 1\n",
  "true_class = 1\n",
  "\n",
  "# get the model classification first, so the explanation below is explicitly for the predicted class\n",
  "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
  "classification = output.argmax(dim=-1).item()\n",
  "class_name = classifications[classification]\n",
  "\n",
  "# generate an explanation for the predicted class\n",
  "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0, index=classification)[0]\n",
  "# min-max normalize to [0, 1]; the clamp avoids 0/0 (NaN) when every score is equal\n",
  "expl = (expl - expl.min()) / (expl.max() - expl.min()).clamp(min=1e-12)\n",
  "\n",
  "# if the classification is negative, higher explanation scores are more negative\n",
  "# flip for visualization\n",
  "if class_name == \"NEGATIVE\":\n",
  "    expl *= (-1)\n",
  "\n",
  "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
  "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
  "vis_data_records = [visualization.VisualizationDataRecord(\n",
  "    expl,                              # word_attributions\n",
  "    output[0][classification].item(),  # pred_prob\n",
  "    classification,                    # pred_class\n",
  "    true_class,                        # true_class\n",
  "    classification,                    # attr_class: the class that was explained\n",
  "    1,                                 # attr_score (constant placeholder)\n",
  "    tokens,                            # raw_input_ids (token strings)\n",
  "    1)]                                # convergence_score (constant placeholder)\n",
  "visualization.visualize_text(vis_data_records)"
 ], "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[('[CLS]', 0.0), ('this', 0.4267549514770508), ('movie', 0.30920878052711487), ('was', 0.2684089243412018), ('the', 0.33637329936027527), ('best', 0.6280889511108398), ('movie', 0.28546375036239624), ('i', 0.1863601952791214), ('have', 0.10115814208984375), ('ever', 0.1419338583946228), ('seen', 0.1898290067911148), ('!', 0.5944811105728149), ('some', 0.003896803595125675), ('scenes', 0.033401958644390106), ('were', 0.018588582053780556), ('ridiculous', 0.018908796831965446), (',', 0.0), ('but', 0.42920616269111633), ('acting', 0.43855082988739014), ('was', 0.500239372253418), ('great', 1.0), ('.', 0.014817383140325546), ('[SEP]', 0.0868983045220375)]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
11 (1.00)11.00 [CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
11 (1.00)11.00 [CLS] this movie was the best movie i have ever seen ! some scenes were ridiculous , but acting was great . [SEP]
" ] }, "metadata": {}, "execution_count": 12 } ] },
 { "cell_type": "markdown", "metadata": { "id": "oO_k1BtSPVt3" }, "source": [ "# Negative sentiment example" ] },
 { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 219 }, "id": "gD4xcvovI1KI", "outputId": "e4a50a94-da4c-460e-b602-052b09cec28f" }, "source": [
  "# encode a sentence\n",
  "text_batch = [\"I really didn't like this movie. Some of the actors were good, but overall the movie was boring.\"]\n",
  "encoding = tokenizer(text_batch, return_tensors='pt')\n",
  "input_ids = encoding['input_ids'].to(\"cuda\")\n",
  "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
  "\n",
  "# true class is negative - 0\n",
  "true_class = 0\n",
  "\n",
  "# get the model classification first, so the explanation below is explicitly for the predicted class\n",
  "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
  "classification = output.argmax(dim=-1).item()\n",
  "class_name = classifications[classification]\n",
  "\n",
  "# generate an explanation for the predicted class\n",
  "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0, index=classification)[0]\n",
  "# min-max normalize to [0, 1]; the clamp avoids 0/0 (NaN) when every score is equal\n",
  "expl = (expl - expl.min()) / (expl.max() - expl.min()).clamp(min=1e-12)\n",
  "\n",
  "# if the classification is negative, higher explanation scores are more negative\n",
  "# flip for visualization\n",
  "if class_name == \"NEGATIVE\":\n",
  "    expl *= (-1)\n",
  "\n",
  "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
  "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
  "vis_data_records = [visualization.VisualizationDataRecord(\n",
  "    expl,                              # word_attributions\n",
  "    output[0][classification].item(),  # pred_prob\n",
  "    classification,                    # pred_class\n",
  "    true_class,                        # true_class\n",
  "    classification,                    # attr_class: the class that was explained\n",
  "    1,                                 # attr_score (constant placeholder)\n",
  "    tokens,                            # raw_input_ids (token strings)\n",
  "    1)]                                # convergence_score (constant placeholder)\n",
  "visualization.visualize_text(vis_data_records)"
 ], "execution_count": 13, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[('[CLS]', -0.0), ('i', -0.19109757244586945), ('really', -0.1888734996318817), ('didn', -0.2894313633441925), (\"'\", -0.006574898026883602), ('t', -0.36788827180862427), ('like', -0.15249046683311462), ('this', -0.18922168016433716), ('movie', -0.0404353104531765), ('.', -0.019592661410570145), ('some', -0.02311306819319725), ('of', -0.0), ('the', -0.02295113168656826), ('actors', -0.09577538073062897), ('were', -0.013370633125305176), ('good', -0.0323222391307354), (',', -0.004366681911051273), ('but', -0.05878860130906105), ('overall', -0.33596664667129517), ('the', -0.21820111572742462), ('movie', -0.05482065677642822), ('was', -0.6248231530189514), ('boring', -1.0), ('.', -0.031107747927308083), ('[SEP]', -0.052539654076099396)]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (1.00)11.00 [CLS] i really didn ' t like this movie . some of the actors were good , but overall the movie was boring . [SEP]
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (1.00)11.00 [CLS] i really didn ' t like this movie . some of the actors were good , but overall the movie was boring . [SEP]
" ] }, "metadata": {}, "execution_count": 13 } ] },
 { "cell_type": "markdown", "source": [ "# Choosing which class to explain" ], "metadata": { "id": "UUn2_SMPNG-Y" } },
 { "cell_type": "code", "source": [
  "# encode a sentence\n",
  "text_batch = [\"I hate that I love you.\"]\n",
  "encoding = tokenizer(text_batch, return_tensors='pt')\n",
  "input_ids = encoding['input_ids'].to(\"cuda\")\n",
  "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
  "\n",
  "# true class is positive - 1\n",
  "true_class = 1\n",
  "\n",
  "# generate an explanation for a chosen class (0 = NEGATIVE), not necessarily the predicted one\n",
  "target_class = 0\n",
  "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
  "# min-max normalize to [0, 1]; the clamp avoids 0/0 (NaN) when every score is equal\n",
  "expl = (expl - expl.min()) / (expl.max() - expl.min()).clamp(min=1e-12)\n",
  "\n",
  "# get the model classification in this cell, so it does not rely on a `classification` left over from another cell\n",
  "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
  "classification = output.argmax(dim=-1).item()\n",
  "\n",
  "# the sign follows the explained class: scores explaining NEGATIVE are flipped for visualization\n",
  "class_name = classifications[target_class]\n",
  "if class_name == \"NEGATIVE\":\n",
  "    expl *= (-1)\n",
  "\n",
  "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
  "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
  "vis_data_records = [visualization.VisualizationDataRecord(\n",
  "    expl,                              # word_attributions\n",
  "    output[0][classification].item(),  # pred_prob\n",
  "    classification,                    # pred_class\n",
  "    true_class,                        # true_class\n",
  "    target_class,                      # attr_class: the class that was explained\n",
  "    1,                                 # attr_score (constant placeholder)\n",
  "    tokens,                            # raw_input_ids (token strings)\n",
  "    1)]                                # convergence_score (constant placeholder)\n",
  "visualization.visualize_text(vis_data_records)"
 ], "metadata": { "id": "VQVmMFnzhPoV", "outputId": "26a43f8a-340c-4821-b39c-80105a565810", "colab": { "base_uri": "https://localhost:8080/", "height": 219 } }, "execution_count": 14, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[('[CLS]', -0.0), ('i', -0.19790242612361908), ('hate', -1.0), ('that', -0.40287283062934875), ('i', -0.12505637109279633), ('love', -0.1307140290737152), ('you', -0.05467141419649124), ('.', -6.108225989009952e-06), ('[SEP]', -0.0)]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (0.91)11.00 [CLS] i hate that i love you . [SEP]
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (0.91)11.00 [CLS] i hate that i love you . [SEP]
" ] }, "metadata": {}, "execution_count": 14 } ] },
 { "cell_type": "code", "source": [
  "# encode a sentence\n",
  "text_batch = [\"I hate that I love you.\"]\n",
  "encoding = tokenizer(text_batch, return_tensors='pt')\n",
  "input_ids = encoding['input_ids'].to(\"cuda\")\n",
  "attention_mask = encoding['attention_mask'].to(\"cuda\")\n",
  "\n",
  "# true class is positive - 1\n",
  "true_class = 1\n",
  "\n",
  "# generate an explanation for a chosen class (1 = POSITIVE), even if the model predicts the other one\n",
  "target_class = 1\n",
  "expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]\n",
  "# min-max normalize to [0, 1]; the clamp avoids 0/0 (NaN) when every score is equal\n",
  "expl = (expl - expl.min()) / (expl.max() - expl.min()).clamp(min=1e-12)\n",
  "\n",
  "# get the model classification in this cell, so it does not rely on a `classification` left over from another cell\n",
  "output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)\n",
  "classification = output.argmax(dim=-1).item()\n",
  "\n",
  "# the sign follows the explained class: scores explaining NEGATIVE are flipped for visualization\n",
  "class_name = classifications[target_class]\n",
  "if class_name == \"NEGATIVE\":\n",
  "    expl *= (-1)\n",
  "\n",
  "tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())\n",
  "print([(tokens[i], expl[i].item()) for i in range(len(tokens))])\n",
  "vis_data_records = [visualization.VisualizationDataRecord(\n",
  "    expl,                              # word_attributions\n",
  "    output[0][classification].item(),  # pred_prob\n",
  "    classification,                    # pred_class\n",
  "    true_class,                        # true_class\n",
  "    target_class,                      # attr_class: the class that was explained\n",
  "    1,                                 # attr_score (constant placeholder)\n",
  "    tokens,                            # raw_input_ids (token strings)\n",
  "    1)]                                # convergence_score (constant placeholder)\n",
  "visualization.visualize_text(vis_data_records)"
 ], "metadata": { "id": "WiQAWw0-imCg", "outputId": "a8c66996-dcd0-4132-a8b0-2346d9bf9c7b", "colab": { "base_uri": "https://localhost:8080/", "height": 219 } }, "execution_count": 15, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "[('[CLS]', 0.0), ('i', 0.2725590765476227), ('hate', 0.17270179092884064), ('that', 0.23211266100406647), ('i', 0.17642731964588165), ('love', 1.0), ('you', 0.2465524971485138), ('.', 0.0), ('[SEP]', 0.00015733683540020138)]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (0.91)11.00 [CLS] i hate that i love you . [SEP]
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "text/html": [ "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
10 (0.91)11.00 [CLS] i hate that i love you . [SEP]
" ] }, "metadata": {}, "execution_count": 15 } ] } ] }