{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "xoIgh792ONRc" }, "source": [ "# A full training" ] }, { "cell_type": "markdown", "metadata": { "id": "P_itPOQhONRm" }, "source": [ "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "vNjGFVb5ONRo", "outputId": "dcc3fa1e-e2e2-41b5-9a92-eb0fb682d10d", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting datasets\n", " Downloading datasets-2.14.4-py3-none-any.whl (519 kB)\n", "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/519.3 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m112.6/519.3 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting evaluate\n", " Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.4/81.4 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting transformers[sentencepiece]\n", " Downloading transformers-4.32.1-py3-none-any.whl (7.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m93.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n", "Collecting dill<0.3.8,>=0.3.0 (from datasets)\n", " Downloading 
dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.1)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m23.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n", "Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)\n", " Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m24.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Collecting responses<0.19 (from evaluate)\n", " 
Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.12.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (2023.6.3)\n", "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[sentencepiece])\n", " Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m95.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting safetensors>=0.3.1 (from transformers[sentencepiece])\n", " Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m81.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[sentencepiece])\n", " Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m72.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from transformers[sentencepiece]) (3.20.3)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n", "Requirement already satisfied: 
async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0.0,>=0.14.0->datasets) (4.7.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.7.22)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", "Installing collected packages: tokenizers, sentencepiece, safetensors, xxhash, dill, responses, multiprocess, huggingface-hub, transformers, datasets, evaluate\n", "Successfully installed datasets-2.14.4 dill-0.3.7 evaluate-0.4.0 huggingface-hub-0.16.4 multiprocess-0.70.15 responses-0.18.0 safetensors-0.3.3 sentencepiece-0.1.99 tokenizers-0.13.3 transformers-4.32.1 xxhash-3.3.0\n", "Collecting accelerate\n", " Downloading accelerate-0.22.0-py3-none-any.whl (251 kB)\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.2/251.2 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.23.5)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.12.2)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.7.1)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.27.2)\n", "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: 
mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", "Installing collected packages: accelerate\n", "Successfully installed accelerate-0.22.0\n" ] } ], "source": [ "%pip install datasets evaluate transformers[sentencepiece]\n", "%pip install accelerate\n", "# To run the training on TPU, also install a torch_xla build that matches the installed\n", "# torch and Python versions. The command below is kept for reference only: its wheel\n", "# targets torch 1.9 / Python 3.7 (cp37) and will not install on a Python 3.10 runtime.\n", "# %pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "k864HfF_ONRr", "outputId": "2eed1ad1-5b62-4a46-9723-244b01827d5b", "colab": { "base_uri": "https://localhost:8080/", "height": 561, "referenced_widgets": [ "eaf689331b344c6fbe46407137430e00", "e68682d157834f29a423bc7c8a68e2a6", "ae5c6dc88a7a4ceaa4be70771b692831", "ff0757c12ec74d8ab7d7b19c2e9f07d1", "d46e4aeb197a4d1e915cf59d9d309015", "b6fdfada97644113b802d18a32a5ca31", "c37fd08cd00c43febd56caa686c01aaa", "fd001e1c512948008151dabb285aea02", "23d071f934034910b48cb920b123fd49", "93c4cca9371d49e4935ab472d2cc2a76", "6cd370e6a77948cc95657544f27d0337", "e362a6e03a9c4c80983620881f345994", "475325f0647e4a5d8ef2cabe215596fd", "52b6f325b0414d77a4760418a36f9dfb", "54ea8c9f43754500893daca893ee8f24", "8f2d093f7bad45719647fb2ad16090dc", "8409b6ef1a4c451eaf08eb236d50deee", "f21c4670586e4edd9fd0722a50cbf563", "033cfa67af654c9194ee3a2a3f2c7b2f", "8de7c949aec346d592e35f939df99039", "fd6ed12374c945a2a8975b35cc24d6fd", "a6495b0af8c8486c97168b14ba0a45a9", "e26051b70a964b488821c664a8a30c8e", "933fc65e53d246bb957f389dde56706c", "96afc4565f0b4da1a54d5a199e2593d2", "576e1ed847d4400fb3b16a3c79c82d6a", "2dc23d696e20436dbaebeed6985d9deb", "ed025c6413ef44ccb08f59ff743669c0", "d003c5fab5c64428b68a3137d7121df3", "c4d42970bfc44131a1a51ce12cd1c833", "b09562fb647743e4b1671791a391e1c5", "9015219d7a69457798669f2bbafcfc6a", "a25723e6ecdd40a89ec5ac2221b5b22a", "7dac64abc6b74508a982f2dd314b0e62",
"ba9d806b5d32415083a9927eeef0d22a", "22a8cd08954f48de862530c22e2e9d69", "64c9383101d14d20bfcbc53ef511d072", "8ae0c543ead042faa8d79f5874384c79", "2b50f3536721483db1703a878ef0b4f2", "5a8dfe0be2644e4b8ec997aee6c73b2b", "51a13dca906b48a38ed600b50eadc2a2", "fc7e1f69763b4d5e91cae9a1b20a330a", "89f0d75a3dcb4d129f3f84a936515e61", "e4904a9bf70249559da2a174ae6de76c", "1f784ec558a447feaefe148e49d80f6b", "f7e633353c064c1793d9646b10853d49", "d3531731df7a44708131b29c526a11b0", "0dcab3110786460190844e6defe8ed43", "7dbed8e861d14603a069bd4bbef43324", "806cfe4aca404e7897be1ccd9d29ac29", "631c011e8a494c15a799ed0d85341558", "ade2b81591204b728e02639819c5fe16", "23576920bd864e31b00bb7361882b6a2", "d0d58964d06d451c8c5fd605abde8205", "bad527cf3c9245cfbf1ba08c9a804281", "9e3fa4fe4776467f882d5ef77a68ae59", "61e5f7c50ccc4d9b960a5cc6d183e966", "03891e2fb2474b5f8d4d6ff41636acfb", "ded630325b9b4f529b340e616f15f79f", "a138336371104f378bcfcbd2143e3320", "d1e224a5e34842838a0c03d590c31aad", "a8861fe8c77d42f0b82ab4d6787dd4f0", "926318f5514748c1b45273b345be3e46", "259f4d365aec41b4a29d2cbf996e81a0", "d1ec0b88904247feb449f06398f9b9ca", "7e4ab643b3c94995bd97c5b2b6a361f3", "d87440bb0b734541a37ac101ba44dda1", "c52b1c5126c041dd869dd7cc9678336a", "4c80b02977d54a92b982a17a2972ab40", "40c1cd4e7aff4de09561e5075ff1eb35", "0d31e46b5b1e4ac5a70e35a2565286bf", "73707bc2868d4d86b967f1602ca884ca", "4a7f795880e7414d943f9efb24be03fe", "104252a494a64030ad0e601c53be453d", "cc62b5f4906f4eb58611e7956999b5f7", "91750e3f81424a50b35c23578f6b8762", "0b4d81cf34ec49cd8289d371e595199b", "33503596733d41ad96b6ffd2cfdd9c60", "a770d8bb3abf41279c99be6dc53139ec", "540c0df7a1804543ac3b9346ce65f44a", "f034ffd134e54c54ac2c899bacd57538", "829191fddd77482e9ca9458f207912cf", "ace3e2c63ab74cf7a5d2954d57089b88", "9f9318ca2c33441bb35c4b5471694ca5", "e50d521f7ba84c939f04a81614f968cd", "feb557fb55d04f4e81ce362755f31647", "fce923cf2c134c26ac5d77d21d9b1e0e", "a3f4149a1cf0494da213a16ec36f7c82", "bd14df41aac54069849f9da75dd7b645", 
"df5044a8c8a843e7a07de3b27fc5e52b", "54ba349dfd49471e8dfea82e9b7b3241", "eb8316db019f40578f2988d883a7ed41", "1025791a769b4d128e7ecf1ffe8a273a", "447d978e7c9d4b809f88f664480a2682", "13124f8e7174413b99059bcce09fd46f", "86a5af6a65aa4abda9875799ba1762d7", "66c6856400a54b228261ebeddda8e2a4", "9416322cadb4418e98e6ec62735c0c2b", "b1025c6c84494d43b323a29a75995c7c", "05bda1fc6fc94c6ca30c411d1000ea0a", "a59e6620c9624803a065597923f04a01", "e985a5c7c02447518eda212c960fd50e", "14ae59b8e1fd4fc3a313bcc88b14c419", "e1d3211aa12c4201b8e2c92b4edd190c", "e55c443d94fa4682b186d32db82a4cdd", "d76e7e56805d4bffad8894aad10a73b9", "f02732ebd910489290b2e12ab4e9b1d7", "e851336fe5ae4ceb988569d244deb3ab", "4823029cc03140299ec1af5d1ebfe67a", "bc7a3e2984d94e1d8cc5691abf2cff5e", "0c196ad36fde451c8b44fdd2964c0720", "af05eee3a68a4982aaf4115f96463406", "b18e51551a774becaf0dbf3aca685443", "759479dffe8c4cd292397a8b922b88d7", "d77015c08305428b9e7df6d7ff344b16", "449a1d545206400791d69c3ef33a1121", "b9d62a2248594046a86c8f572f3fcb42", "b23dc3e4f7114ff1823cf6ad63eb8478", "16e7d532560a426f919667c3f1e17130", "d9de03051dca4e63a75fff287926aea3", "ffb5c400c16848829bab39aa51e2c77d", "496378556e3b46048aac01133ffd96fb", "aba880f1de9846d7b629acf9f634ec89", "bf29b7b04b7d4a7a910bca90f20e5791", "d0a02bef657f411ead1befc79340f89d", "90e4c4b59e8f435ea4a4790d5ddeb43b", "b968754a84a142259f4dc4264a600a9f", "a93553fc5a54471ca802a411c09dcd7f", "138e998347b64c60b4cb46796a155c26", "d3c755622cb946c6bfe75c59c022a590", "7fa878be251643c7aa7d41bd4183cafd", "69bb5b5fdf37481280f6401d0dbb49c4", "709aa4c7bc5c4fdf981019fa2baf269a", "c96e91c4946e4bc5a21fad011cdf8e13", "25969bcd12164e80a6d8c755446558a6", "e2060c879be5471cb442c85964cd6f4d", "d96b52f72f494182ab031ace5f6c2364", "00c84e22939e4f659f622892d4686c34", "d0b72f36b8ee4c62a080c0457b51948c", "b922b17821074168ac5f9fa364c222cb", "3441badb8e24419c82ffc4a46669d63c", "90dd79d303674ee9a518a0dba124a364", "def1054b65b24fdf8714f02320a2d2d9", "06db26c8140a406abfe59a34088b6000", 
"9666ca267d4941aba65befbe5db0fc17", "58af380ce2584c8fba82f082ff50d2b9", "c091deff44ef4546aba33faf917e94b8", "52c1865782cf4bb398d4c087a8d6e75a", "c115e785a83546ce9255947bdb9d5f61", "1ab7b61c2e9549f6b705c5fd3868eed7", "c6f1b53ac8124d4ca16b905a2eda836b", "b7cb69816f6e4238a6e1dae1393fbb67", "76f3c41bb7d6401fa311d8168bd3cac9", "abc3a4e0f81e4df1af86933f22761eaf", "4d5d7dcfe51a445499d99f60735e318e", "94943ae661ad4dec956b5278d3223528", "00b3157879b0472fab9960a8fe5d3cbf", "a984e71b709e4e73b3c32d939a2f2edf", "25b6af5b785e4bb4bec1de249568a3df", "f49fff4d4eea4fec8f69088e9d57e34d", "d66627053ebb4d3facb62302286d89e2", "e483ca63d859462ebeb5ca2d25dd4672", "954ea9a67afa4bbd884708424e36321c", "8dc019cfa2a947f391554b5821463e41", "2bedada0b8d24e6d854a040ef1e3985f", "00914e777b3b4ff68851946a503c4cf2", "9264472c7c204b8b956dc893fe4cce4d", "bcb4f25e6c504392aa4bb54fd0db612a", "56fa3e8014a34637acc841087bd53ec1", "db22e549ab6a45599928dcf5ce144cfb", "3b87342e5a994f12a0af5b465bef73b8", "5da9cb1be3ac4719aa32e311c6a000cf", "5c86881a02614293ba6aa55f12ce9c54", "6d555a406fe8435394c0c5d4ea536428", "b015b43a78f44a41beeea35d4205b655", "7be58e0c41e84b5a977d6466501967e7", "8c51da6627674e0cb3e16159bcc1a4b3", "16de7db2a6d543a482d3c3eac8583923", "4e8faadc85e9478d81254bdba3ecd4d3", "8f0c2cd8cdbe4401a92668ecd6193b1d", "dcc28c1c8c24450681ef4647e4479ae0", "e645a0d30c9e4cc3a1d558e49edebf81", "93efe74b50514d62b009cbebdb91b547", "d187b7bbda91469d84ea3eafe750620a", "53a2e95a7e4b4d5bb8130534c4e08188", "65710438fa2b43599b3a5285d878d349", "6ff05b95ca7e45279c21623cf0baef05" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Downloading builder script: 0%| | 0.00/28.8k [00:00) torch.Size([8, 2])\n" ] } ], "source": [ "outputs = model(**batch)\n", "print(outputs.loss, outputs.logits.shape)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "xuvVm-dgONR3", "outputId": "ef5d02b7-2c52-471a-d91f-dc097f4a3e3c", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ 
], "source": [ "from torch.optim import AdamW\n", "\n", "# transformers.AdamW is deprecated in favor of torch.optim.AdamW. Their defaults differ\n", "# (transformers: eps=1e-6, weight_decay=0.0; torch: eps=1e-8, weight_decay=0.01),\n", "# so pin the previous values explicitly to keep the training run unchanged.\n", "optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "PNTxdvEWONR5", "outputId": "d86b6ed7-9d62-4f1e-912c-4423ef160a65", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "1377\n" ] } ], "source": [ "from transformers import get_scheduler\n", "\n", "num_epochs = 3\n", "num_training_steps = num_epochs * len(train_dataloader)\n", "lr_scheduler = get_scheduler(\n", " \"linear\",\n", " optimizer=optimizer,\n", " num_warmup_steps=0,\n", " num_training_steps=num_training_steps,\n", ")\n", "print(num_training_steps)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "sJyO1F-2ONR6", "outputId": "9ad5cdf1-da95-4930-deaf-e5176ef093f8", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "device(type='cuda')" ] }, "metadata": {}, "execution_count": 11 } ], "source": [ "import torch\n", "\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "model.to(device)\n", "device" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "17LSxvNkONR7", "outputId": "b1f84542-8be6-45ff-e9c7-a407f15b5279", "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "c29c33bd969c45349399e44939d1af7b", "2bbac8b2c2c94d9ba970e71dc48b487d", "ef777c7041da4ae1ba7918ce670c3003", "ce4b9897906c42918104f6bf3f736c84", "32956b707d0e476cba358d67d5c61e60",
"5a1804af344b4843a9f8e3090941ad1b", "eed6f9122c5e46bd8711753ee9bffbc2", "be4d50c521c642eeb2e03b65de4ef7b2", "bbace1b853744944ac97165cec20774f", "7eeb6f08bfd84443a6fd9639109ce683", "5c9f479a1ce24e9c89279b06887c89ec" ] } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ " 0%| | 0/1377 [00:00