diff --git "a/xnli_model.ipynb" "b/xnli_model.ipynb"
new file mode 100644--- /dev/null
+++ "b/xnli_model.ipynb"
@@ -0,0 +1,9640 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ScLdn2bHZFKg",
+ "outputId": "5557a9de-4d38-436d-8251-00c96472a8f3"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting datasets\n",
+ " Downloading datasets-2.10.1-py3-none-any.whl (469 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m469.0/469.0 KB\u001b[0m \u001b[31m29.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting transformers[sentencepiece]\n",
+ " Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.3/6.3 MB\u001b[0m \u001b[31m69.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (2.25.1)\n",
+ "Collecting multiprocess\n",
+ " Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.0/132.0 KB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting huggingface-hub<1.0.0,>=0.2.0\n",
+ " Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.3/190.3 KB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)\n",
+ "Collecting dill<0.3.7,>=0.3.0\n",
+ " Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 KB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (4.64.1)\n",
+ "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.8/dist-packages (from datasets) (23.0)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from datasets) (1.22.4)\n",
+ "Collecting xxhash\n",
+ " Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.0/213.0 KB\u001b[0m \u001b[31m23.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.4)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (6.0)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)\n",
+ "Collecting responses<0.19\n",
+ " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
+ "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n",
+ " Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m105.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers[sentencepiece]) (3.9.0)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers[sentencepiece]) (2022.6.2)\n",
+ "Collecting sentencepiece!=0.1.92,>=0.1.91\n",
+ " Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m73.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: protobuf<=3.20.2 in /usr/local/lib/python3.8/dist-packages (from transformers[sentencepiece]) (3.19.6)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (3.0.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.5.0)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (1.26.14)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
+ "Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (4.0.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n",
+ "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
+ "Installing collected packages: tokenizers, sentencepiece, xxhash, dill, responses, multiprocess, huggingface-hub, transformers, datasets\n",
+ "Successfully installed datasets-2.10.1 dill-0.3.6 huggingface-hub-0.12.1 multiprocess-0.70.14 responses-0.18.0 sentencepiece-0.1.97 tokenizers-0.13.2 transformers-4.26.1 xxhash-3.2.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install datasets transformers[sentencepiece]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 603,
+ "referenced_widgets": [
+ "7dc5cf197d0645f0bf825b6e9db4b896",
+ "ed38b3a6107947148f142bcc05c3df08",
+ "871e1e1bbc354ce9b0ef73a9d979daaf",
+ "7e537d6ef1cc4c5cbd273a5c336d70d0",
+ "d32ff02a9ece4c11844ec133eaef7b8b",
+ "468cb9a64c60416a90deb593c616626c",
+ "fa06607517924e26af2c3730cbdc80c2",
+ "4616fe6206cb48fe9e60c767d44309d6",
+ "90a8b4e90e79484bb4082f085af55950",
+ "d76928ae46fe4f08b6ee1f1f875721a2",
+ "8214a8a5b4b0465591b1dd2b2064a3b3",
+ "e6e0d7e3b4e1406695e20632cbc4364b",
+ "e8eb3daad6e444159bd3ef4d0f3f93f9",
+ "b3546212ef1e4d99849f78c8c32bda3b",
+ "1c6cb090033344ab86f54ee4f2fdcdf8",
+ "1fac171a5f0b4243be54689e891c8522",
+ "bb50ed6cc7974d9e80ee7f390781204f",
+ "7428e7a29e854fbe802a02a2fddb242d",
+ "48124689e1674cb38fadde97c377ca36",
+ "deb064d0790649e49805915128d62394",
+ "4344404a25794d74a0edd4104778e79d",
+ "bc7d8485a5e94f5abb2ca5cfaad1c290",
+ "bdacff3ac0bd4b0ca9b883e945c0ddeb",
+ "24c233bdd4054d578256e7503000d0f7",
+ "1c9f9aa2164040449b4e7766278d6cdc",
+ "86f8a6744b4e48619d524f82322946ff",
+ "682a5867e9df415f8c5c93673674271f",
+ "cbbca9a7bf54417981d280c40d972eda",
+ "a08086db95f24bbeaa2a66f79aa221cd",
+ "d38f78e3b90a448099d1eb496b07f083",
+ "21dbbdab77e649019172559075b09e01",
+ "58c78f6aeb75441998ba8886598bb1d1",
+ "5f46c061bb6447f1bc4055469682cb3f",
+ "ca8e307f01bc4315a108fc657c064104",
+ "e81f560319624ada9bb45871f3a00a91",
+ "e29384e17a5c47bcb8b472f3eb0946c9",
+ "6bddcd5055fd46c0885b3e9aea1e8857",
+ "96729ec0d47a476a95aca46221c1b7bc",
+ "b0a28c6bc7a946599bd86575be206b64",
+ "8c9202cf4e1542de88a74c53b8e0c44b",
+ "54b15e99206549fea092c8f0c4d853be",
+ "eda72f53dc4741afbb3ff33041db9f82",
+ "4e3a0a282d5d43f7940c6b521937b074",
+ "af91211491ed4209a905b6afffd1a7e1",
+ "074aed9dc4ec42c5836cb8df7a427da5",
+ "d7966925f0cb4c199519a003861e0363",
+ "af897d46e63e495d86e1897a02a2c3c6",
+ "db9df9c91de644bc978e00ca2770ff0e",
+ "6486c55fcc7c46868164eaed11487217",
+ "97529f7b40494ca4939d49142695947c",
+ "2295c16f5f20474a9e4579848ad3302d",
+ "07c399fa32d5400884a93a357dc8345b",
+ "b53b2bbead7b406a952168f81d5c262d",
+ "691533aa5b004289a2f7f06db392013b",
+ "a0569c75957b454e899616f76cb19783",
+ "f1ae69f0dbe74a5b9e01937561b3a9bb",
+ "a926d74caa22408d9e29ddf57ad9656b",
+ "b7a8c776adff4dc79687739f0c6fdfec",
+ "66f819b77cf84eeb82b40ad56b353659",
+ "208d0d2f8e764f009926c66d448add68",
+ "9f05a7d35f5442aab079e17aa78e9d14",
+ "7742f397fb264696a6346819f6cfc522",
+ "cd468c50ab79495c8ee5d2a3387edd62",
+ "ec3cde9cbcd745009a2736e6c5416eb6",
+ "fb55f28d75544958b32d25003ff69fa9",
+ "051457b89b46410b9871c5dffb2a2366",
+ "a8819d3249e94daf8c50bdcefd42f31d",
+ "9c2bf9e5453343c4b27b7436cfc5c4e5",
+ "5ef6d5e327b9407dba157d2ecf8da5c4",
+ "6b2ffb369b6f4f2c8a638dc3606bc731",
+ "d6d31a25e51d4aa3a930262a5983b921",
+ "14df8f66185a4e2989e4226fd7c4539f",
+ "b0977a7d702b4554ba1abb391b208798",
+ "4801c33b951240bca87d905c2b35b8d5",
+ "3fb4e307fe07441f97e3af1cf4b4e30b",
+ "e145b17c4c8f47d3bfc509fa5c9b6e37",
+ "cb480f97fe944c21900f18adefc186b9",
+ "dd94596f55624b01b33eebca8d144983",
+ "8c5ff33ccb0149029f85636a847e8dec",
+ "2bb7c83571cf4ecc9beab86caf785d64",
+ "50f4807cbca6470ab75c05a159b4ed21",
+ "4d76058688854ff9b9e869407c6c8cfd",
+ "2c06f883f4da403ca0890c7407e0d2e4",
+ "0843812656ae4c8aa541e966f90333d4",
+ "e3c0434a4eb0464aa76dfbeac206c833",
+ "05dfb651142942988f8a9d07de2bb88e",
+ "5932421f7540409fbd950448454fec89",
+ "c275778693e9437bb209444b5743109b",
+ "1f76a06eb5ee4fc7a48367f62deb452b",
+ "92da43d475374950a8fed748a2d5f27d",
+ "7dde417459764908bd1fb3e4b0c9afe3",
+ "87c35b48b8bf466690efafe103f94c2b",
+ "845f86f367694c5d93267586c34527de",
+ "6ecc052127b44cbcb1cdc11423fdd46c",
+ "95780e03b78a4bb585cdb2fa26436657",
+ "45c7314f894f4d5690f4be9bc950d9f5",
+ "23cb4087845949d1b841becb1c0022e8",
+ "1120f4ed9a784137ab3747c9e8e4c515",
+ "ac2985076cac43a5bf804917270c6115",
+ "10319e6e18af41398356c6480cdd6805",
+ "076de24dcc814f22947efee893aae1eb",
+ "6c323d0324fc4990890aaf6ace9eb9f5",
+ "1b1b031961824dea8b97abfa303eba27",
+ "0fc980945247465998adada7b38d5ac1",
+ "11fdcf8da69a425c896510840fb3bd0c",
+ "18de97bbe4ac45b58dc73192a0be153e",
+ "1427389ba9304f0694035e49f9f4ffb8",
+ "b1b90df89af3453d8eb63be5d64cb230",
+ "3066f35c1d384a62843d9ecf97faff48",
+ "e15c66b4f49f49c293413696e74a43e0",
+ "5fb2963a30b64a8f945411486846bd7f",
+ "0922a9153797449dbbe011ef9d3c1e09",
+ "20f5fa65a51644cd83e8c82db4a21074",
+ "7f238a99a82349c1b44051dc14343703",
+ "16cf059dba3c4cb2b86352ebb0a901af",
+ "a647f56e80f144ab91a7745188807847",
+ "bdacf0f0dcdf40e09b56ab70fe5bb987",
+ "0898b3a4a8bc4735b203108c166c2cf0",
+ "7548a410fbd94386a081d63eb7d31496",
+ "9f59d950a7df4aa893c788181c3270c0",
+ "039092d89ef7460cb38855cd0d3e4323"
+ ]
+ },
+ "id": "EKHMCuvJZGBI",
+ "outputId": "7ebcb889-76d8-423f-ccc2-649c957b09cf"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading builder script: 0%| | 0.00/8.78k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "7dc5cf197d0645f0bf825b6e9db4b896"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading metadata: 0%| | 0.00/36.6k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "e6e0d7e3b4e1406695e20632cbc4364b"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading readme: 0%| | 0.00/18.3k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "bdacff3ac0bd4b0ca9b883e945c0ddeb"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Downloading and preparing dataset xnli/en to /root/.cache/huggingface/datasets/xnli/en/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd...\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading data files: 0%| | 0/2 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "ca8e307f01bc4315a108fc657c064104"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading data: 0%| | 0.00/466M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "074aed9dc4ec42c5836cb8df7a427da5"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading data: 0%| | 0.00/17.9M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "f1ae69f0dbe74a5b9e01937561b3a9bb"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Extracting data files: 0%| | 0/2 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a8819d3249e94daf8c50bdcefd42f31d"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating train split: 0%| | 0/392702 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "dd94596f55624b01b33eebca8d144983"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating test split: 0%| | 0/5010 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "1f76a06eb5ee4fc7a48367f62deb452b"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating validation split: 0%| | 0/2490 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "10319e6e18af41398356c6480cdd6805"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Dataset xnli downloaded and prepared to /root/.cache/huggingface/datasets/xnli/en/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd. Subsequent calls will reuse this data.\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "5fb2963a30b64a8f945411486846bd7f"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['premise', 'hypothesis', 'label'],\n",
+ " num_rows: 392702\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['premise', 'hypothesis', 'label'],\n",
+ " num_rows: 5010\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['premise', 'hypothesis', 'label'],\n",
+ " num_rows: 2490\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 2
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "english_dataset = load_dataset(\"xnli\", \"en\")\n",
+ "english_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 423
+ },
+ "id": "g3XPhmUMZGEK",
+ "outputId": "ad5a7b0c-2afb-403a-aa80-a75f1eaf0de6"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " premise \\\n",
+ "0 Conceptually cream skimming has two basic dime... \n",
+ "1 you know during the season and i guess at at y... \n",
+ "2 One of our number will carry out your instruct... \n",
+ "3 How do you know ? All this is their informatio... \n",
+ "4 yeah i tell you what though if you go price so... \n",
+ "... ... \n",
+ "392697 Clearly , California can - and must - do better . \n",
+ "392698 It was once regarded as the most beautiful str... \n",
+ "392699 Houseboats are a beautifully preserved traditi... \n",
+ "392700 Obituaries fondly recalled his on-air debates ... \n",
+ "392701 in that other you know uh that i should do it ... \n",
+ "\n",
+ " hypothesis label \n",
+ "0 Product and geography are what make cream skim... 1 \n",
+ "1 You lose the things to the following level if ... 0 \n",
+ "2 A member of my team will execute your orders w... 0 \n",
+ "3 This information belongs to them . 0 \n",
+ "4 The tennis shoes have a range of prices . 1 \n",
+ "... ... ... \n",
+ "392697 California cannot do any better . 2 \n",
+ "392698 So many of the original buildings had been rep... 1 \n",
+ "392699 The tradition of houseboats originated while t... 0 \n",
+ "392700 The obituaries were beautiful and written in k... 1 \n",
+ "392701 My husband has been so overworked lately that ... 1 \n",
+ "\n",
+ "[392702 rows x 3 columns]"
+ ],
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " premise | \n",
+ " hypothesis | \n",
+ " label | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Conceptually cream skimming has two basic dime... | \n",
+ " Product and geography are what make cream skim... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " you know during the season and i guess at at y... | \n",
+ " You lose the things to the following level if ... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " One of our number will carry out your instruct... | \n",
+ " A member of my team will execute your orders w... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " How do you know ? All this is their informatio... | \n",
+ " This information belongs to them . | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " yeah i tell you what though if you go price so... | \n",
+ " The tennis shoes have a range of prices . | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 392697 | \n",
+ " Clearly , California can - and must - do better . | \n",
+ " California cannot do any better . | \n",
+ " 2 | \n",
+ "
\n",
+ " \n",
+ " 392698 | \n",
+ " It was once regarded as the most beautiful str... | \n",
+ " So many of the original buildings had been rep... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 392699 | \n",
+ " Houseboats are a beautifully preserved traditi... | \n",
+ " The tradition of houseboats originated while t... | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 392700 | \n",
+ " Obituaries fondly recalled his on-air debates ... | \n",
+ " The obituaries were beautiful and written in k... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 392701 | \n",
+ " in that other you know uh that i should do it ... | \n",
+ " My husband has been so overworked lately that ... | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
392702 rows × 3 columns
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ],
+ "source": [
+ "# creating pandas dataframe to visualize data\n",
+ "import pandas as pd\n",
+ "df_pandas_train_en = pd.DataFrame(english_dataset['train'])\n",
+ "df_pandas_train_en"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "fqiGkiEhZGG4"
+ },
+ "outputs": [],
+ "source": [
+ "# Define a helper that prints a few shuffled samples (five by default) from the training split.\n",
+ "\n",
+ "def show_samples(dataset, num_samples=5, seed=42):\n",
+ " sample = dataset[\"train\"].shuffle(seed=seed).select(range(num_samples))\n",
+ " for example in sample:\n",
+ " print(f\"\\n'>> Premise: {example['premise']}'\")\n",
+ " print(f\"'>> Hypothesis: {example['hypothesis']}'\")\n",
+ " print(f\"'>> Label: {example['label']}'\")\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "laA63ACzZGJr",
+ "outputId": "8ae2d89a-ecc1-40ce-95b6-cbba3b31fe5d"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "'>> Premise: I 'll hurry over that part .'\n",
+ "'>> Hypothesis: \" I 'll be quick with that part . \"'\n",
+ "'>> Label: 0'\n",
+ "\n",
+ "'>> Premise: Shall I tell you why you have been so vehement against Mr. Inglethorp ?'\n",
+ "'>> Hypothesis: I can tell you why you 're being so vehement against Mr. Inglethorp .'\n",
+ "'>> Label: 0'\n",
+ "\n",
+ "'>> Premise: well you know that brings up the interesting subject too you know what would you have who who who would determine what these people do'\n",
+ "'>> Hypothesis: It begs the question of who gets to say what the other people do .'\n",
+ "'>> Label: 0'\n",
+ "\n",
+ "'>> Premise: A great Sather made the sun remain in one place too long , and the heat became too great .'\n",
+ "'>> Hypothesis: It got too hot when a Sather kept the sun in one spot .'\n",
+ "'>> Label: 0'\n",
+ "\n",
+ "'>> Premise: Of course , it will be generally known to-morrow . \" John reflected .'\n",
+ "'>> Hypothesis: The news was about to break , and John had announced that he found out the newspaper would be announcing it tomorrow to the public .'\n",
+ "'>> Label: 1'\n"
+ ]
+ }
+ ],
+ "source": [
+ "# seeing a few samples of data in the dataset\n",
+ "show_samples(english_dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "sZrXKg91ZGMk"
+ },
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import datetime\n",
+ "\n",
+ "def format_time(elapsed):\n",
+ " '''\n",
+ " Takes a time in seconds and returns a string formatted as h:mm:ss (e.g. 0:01:05).\n",
+ " '''\n",
+ " # Round to the nearest second.\n",
+ " elapsed_rounded = int(round((elapsed)))\n",
+ " \n",
+ " # Format as hh:mm:ss\n",
+ " return str(datetime.timedelta(seconds=elapsed_rounded))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "K4rGme5oZGPS"
+ },
+ "outputs": [],
+ "source": [
+ "def good_update_interval(total_iters, num_desired_updates):\n",
+ " '''\n",
+ " This function will try to pick an intelligent progress update interval \n",
+ " based on the magnitude of the total iterations.\n",
+ "\n",
+ " Parameters:\n",
+ " `total_iters` - The number of iterations in the for-loop.\n",
+ " `num_desired_updates` - How many times we want to see an update over the \n",
+ " course of the for-loop.\n",
+ " '''\n",
+ " # Divide the total iterations by the desired number of updates. Most likely\n",
+ " # this will be some ugly number.\n",
+ " exact_interval = total_iters / num_desired_updates\n",
+ "\n",
+ " # The `round` function can round a number to a power of ten, e.g., to the\n",
+ " # nearest thousand: round(exact_interval, -3)\n",
+ " #\n",
+ " # To determine the magnitude to round to, find the magnitude of the total,\n",
+ " # and then go one magnitude below that.\n",
+ "\n",
+ " # Get the order of magnitude of the total.\n",
+ " order_of_mag = len(str(total_iters)) - 1\n",
+ "\n",
+ " # Our update interval should be rounded to an order of magnitude smaller. \n",
+ " round_mag = order_of_mag - 1\n",
+ "\n",
+ " # Round to that magnitude and cast to an int.\n",
+ " update_interval = int(round(exact_interval, -round_mag))\n",
+ "\n",
+ " # Don't allow the interval to be zero!\n",
+ " if update_interval == 0:\n",
+ " update_interval = 1\n",
+ "\n",
+ " return update_interval"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "BuVv0h8hZGSS"
+ },
+ "outputs": [],
+ "source": [
+ "import csv\n", "import os\n",
+ "\n",
+ "def check_gpu_mem():\n",
+ " '''\n",
+ " Uses Nvidia's SMI tool to check the current GPU memory usage.\n",
+ " Reported values are in \"MiB\". 1 MiB = 2^20 bytes = 1,048,576 bytes.\n",
+ " '''\n",
+ " \n",
+ " # Run the command line tool and get the results.\n",
+ " buf = os.popen('nvidia-smi --query-gpu=memory.total,memory.used --format=csv')\n",
+ "\n",
+ " # Use csv module to read and parse the result.\n",
+ " reader = csv.reader(buf, delimiter=',')\n",
+ "\n",
+ " # Use a pandas table just for nice formatting.\n",
+ " df = pd.DataFrame(reader)\n",
+ "\n",
+ " # Use the first row as the column headers.\n",
+ " new_header = df.iloc[0] #grab the first row for the header\n",
+ " df = df[1:] #take the data less the header row\n",
+ " df.columns = new_header #set the header row as the df header\n",
+ "\n",
+ " # Display the formatted table.\n",
+ " #display(df)\n",
+ "\n",
+ " return df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 81,
+ "referenced_widgets": [
+ "a124e5d0a02f4a06bd19c09061bed0bd",
+ "2061eed02320414795bd1c6fd0b45c23",
+ "64d60704523145108ed302301fabd9a1",
+ "fd3529056e37497eaf4e2a45709216b7",
+ "fca9505b30b04fdfb9431b58f733ae1a",
+ "5a53e07c6c1749f19df3d75750bffe04",
+ "848963f7e5c8476292aee3e7caf287ec",
+ "c36694205931434485bed8a6eeb6f4e5",
+ "2f01bcac34ec4808abccaa7de82289d7",
+ "3fe6ccc9a6a24ba2ba8552bd76a078f6",
+ "87b683b5b1fa4527ae93df680d64862b",
+ "9f2fdb60a9a64d6a8edda72b4c5670c3",
+ "6c56b3596512495c8aafeaacd13ebd19",
+ "e6a9a6f9407144469bc229f803d92bfc",
+ "6a9f734ec7a5468eb57c60491693c019",
+ "15ff0cd26b644629a65843eeaef27388",
+ "44e813d6aef741d5b8bdb9ec3681ed8c",
+ "6b1ac7dac55e434ca1ac89c71d8aad02",
+ "4c172f34d1da49419aabcc161c391760",
+ "7b576f3297cc4d19a87a16d89aa9e9c4",
+ "6dc595535eca45999a9d2c3076d76de4",
+ "423fbadce8634b409e36725736622df5"
+ ]
+ },
+ "id": "ogn2oIj6ZGUy",
+ "outputId": "16dcd901-e9ea-4639-bbbc-35bc1c58c173"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)tencepiece.bpe.model: 0%| | 0.00/5.07M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a124e5d0a02f4a06bd19c09061bed0bd"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/615 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "9f2fdb60a9a64d6a8edda72b4c5670c3"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import XLMRobertaTokenizer\n",
+ "\n",
+ "# Download the tokenizer for the XLM-RoBERTa `base` model.\n",
+ "xlmr_tokenizer = XLMRobertaTokenizer.from_pretrained(\"xlm-roberta-base\" )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "XwioGtSXZGXe",
+ "outputId": "d7ed73dc-e652-4ba4-be46-e7229dc3066c"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Input IDs: [0, 28240, 2685, 155255, 38, 2, 2, 11249, 621, 398, 18925, 32, 2]\n",
+ "Tokens: ['', '▁Hey', '▁there', '▁reader', '!', '', '', '▁How', '▁are', '▁you', '▁today', '?', '']\n",
+ "\n",
+ "Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
+ ]
+ }
+ ],
+ "source": [
+ "sentence_1 = \"Hey there reader!\"\n",
+ "sentence_2 = \"How are you today?\"\n",
+ "\n",
+ "# Encode the two sentences together.\n",
+ "encoded = xlmr_tokenizer.encode_plus(sentence_1, sentence_2)\n",
+ "\n",
+ "# Print the IDs of the resulting tokens.\n",
+ "print (\"Input IDs: \", encoded['input_ids'])\n",
+ "\n",
+ "# Convert the token IDs back to strings so we can check them out.\n",
+ "print (\"Tokens: \", xlmr_tokenizer.convert_ids_to_tokens(encoded['input_ids']))\n",
+ "\n",
+ "# The tokenizer returns an attention mask, which masks out PAD tokens. \n",
+ "# Since we aren't doing any padding yet, the mask is just all 1s. \n",
+ "print (\"\\nAttention Mask: \", encoded['attention_mask'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VYhiEqVfZGaE",
+ "outputId": "68119ec2-f269-44d5-eab6-6297b34d9141"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Tokenizing all examples to check sequence lengths...\n",
+ " Tokenized 0 samples.\n",
+ " Tokenized 30,000 samples.\n",
+ " Tokenized 60,000 samples.\n",
+ " Tokenized 90,000 samples.\n",
+ " Tokenized 120,000 samples.\n",
+ " Tokenized 150,000 samples.\n",
+ " Tokenized 180,000 samples.\n",
+ " Tokenized 210,000 samples.\n",
+ " Tokenized 240,000 samples.\n",
+ " Tokenized 270,000 samples.\n",
+ " Tokenized 300,000 samples.\n",
+ " Tokenized 330,000 samples.\n",
+ " Tokenized 360,000 samples.\n",
+ " Tokenized 390,000 samples.\n",
+ "DONE.\n",
+ " 392,702 samples\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "lengths_en = []\n",
+ "\n",
+ "labels_en = []\n",
+ "\n",
+ "print('Tokenizing all examples to check sequence lengths...')\n",
+ "\n",
+ "# Iterate through the dataset...\n",
+ "for ex in english_dataset['train']:\n",
+ "\n",
+ " # Retrieve the premise and hypothesis strings. \n",
+ " premise = ex['premise']\n",
+ " hypothesis = ex['hypothesis']\n",
+ "\n",
+ " \n",
+ " \n",
+ " # Report progress.\n",
+ " if ((len(lengths_en) % 30000) == 0):\n",
+ " print(' Tokenized {:,} samples.'.format(len(lengths_en)))\n",
+ " \n",
+ " # `tokenizer.encode` will tokenize the sentence, map the tokens to ids, \n",
+ " # and add the required special tokens.\n",
+ " encoded = xlmr_tokenizer.encode(\n",
+ " premise,\n",
+ " hypothesis,\n",
+ " add_special_tokens = True,\n",
+ " )\n",
+ "\n",
+ " # Record the length.\n",
+ " lengths_en.append(len(encoded))\n",
+ "\n",
+ " labels_en.append(ex['label'])\n",
+ "\n",
+ "print('DONE.')\n",
+ "print('{:>10,} samples'.format(len(lengths_en)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Gu8zpmSZZGdI",
+ "outputId": "81615835-2490-42e2-cbd7-f96e15b881fd"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " Min length: 6 tokens\n",
+ " Max length: 468 tokens\n",
+ "Median length: 45 tokens\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(' Min length: {:,} tokens'.format(min(lengths_en)))\n",
+ "print(' Max length: {:,} tokens'.format(max(lengths_en)))\n",
+ "print('Median length: {:,} tokens'.format(int(np.median(lengths_en))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 427
+ },
+ "id": "xVz7vpSNZGiX",
+ "outputId": "d4a74308-7751-4e58-8131-98fc5924c33c"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
+ " warnings.warn(msg, FutureWarning)\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "