Spaces:
Runtime error
Runtime error
Alex Martin
commited on
Commit
•
4f61ed2
1
Parent(s):
0b83fb6
Add files via upload
Browse files- milestone3.ipynb +364 -0
milestone3.ipynb
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": []
|
7 |
+
},
|
8 |
+
"kernelspec": {
|
9 |
+
"name": "python3",
|
10 |
+
"display_name": "Python 3"
|
11 |
+
},
|
12 |
+
"language_info": {
|
13 |
+
"name": "python"
|
14 |
+
}
|
15 |
+
},
|
16 |
+
"cells": [
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": null,
|
20 |
+
"metadata": {
|
21 |
+
"colab": {
|
22 |
+
"base_uri": "https://localhost:8080/"
|
23 |
+
},
|
24 |
+
"id": "CyzBIGSg5um5",
|
25 |
+
"outputId": "9a986d1a-d4e4-4d8d-ec9d-a895031f1bf6"
|
26 |
+
},
|
27 |
+
"outputs": [
|
28 |
+
{
|
29 |
+
"output_type": "stream",
|
30 |
+
"name": "stdout",
|
31 |
+
"text": [
|
32 |
+
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
|
33 |
+
"Collecting datasets\n",
|
34 |
+
" Downloading datasets-2.11.0-py3-none-any.whl (468 kB)\n",
|
35 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.7/468.7 kB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
36 |
+
"\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets) (23.1)\n",
|
37 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets) (1.22.4)\n",
|
38 |
+
"Collecting dill<0.3.7,>=0.3.0\n",
|
39 |
+
" Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n",
|
40 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.5/110.5 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
41 |
+
"\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from datasets) (1.5.3)\n",
|
42 |
+
"Collecting aiohttp\n",
|
43 |
+
" Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
|
44 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m57.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
45 |
+
"\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (6.0)\n",
|
46 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (9.0.0)\n",
|
47 |
+
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (2.27.1)\n",
|
48 |
+
"Collecting huggingface-hub<1.0.0,>=0.11.0\n",
|
49 |
+
" Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)\n",
|
50 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.5/224.5 kB\u001b[0m \u001b[31m29.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
51 |
+
"\u001b[?25hCollecting responses<0.19\n",
|
52 |
+
" Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
|
53 |
+
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (4.65.0)\n",
|
54 |
+
"Collecting multiprocess\n",
|
55 |
+
" Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)\n",
|
56 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.9/132.9 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
57 |
+
"\u001b[?25hCollecting xxhash\n",
|
58 |
+
" Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n",
|
59 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m22.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
60 |
+
"\u001b[?25hRequirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (2023.4.0)\n",
|
61 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (23.1.0)\n",
|
62 |
+
"Collecting aiosignal>=1.1.2\n",
|
63 |
+
" Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
|
64 |
+
"Collecting yarl<2.0,>=1.0\n",
|
65 |
+
" Downloading yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (269 kB)\n",
|
66 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m269.4/269.4 kB\u001b[0m \u001b[31m27.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
67 |
+
"\u001b[?25hCollecting async-timeout<5.0,>=4.0.0a3\n",
|
68 |
+
" Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
|
69 |
+
"Collecting multidict<7.0,>=4.5\n",
|
70 |
+
" Downloading multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n",
|
71 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m114.2/114.2 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
72 |
+
"\u001b[?25hRequirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (2.0.12)\n",
|
73 |
+
"Collecting frozenlist>=1.1.1\n",
|
74 |
+
" Downloading frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n",
|
75 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.8/158.8 kB\u001b[0m \u001b[31m18.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
76 |
+
"\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (4.5.0)\n",
|
77 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.11.0->datasets) (3.11.0)\n",
|
78 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (2022.12.7)\n",
|
79 |
+
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (1.26.15)\n",
|
80 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets) (3.4)\n",
|
81 |
+
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2.8.2)\n",
|
82 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2022.7.1)\n",
|
83 |
+
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
|
84 |
+
"Installing collected packages: xxhash, multidict, frozenlist, dill, async-timeout, yarl, responses, multiprocess, huggingface-hub, aiosignal, aiohttp, datasets\n",
|
85 |
+
"Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.11.0 dill-0.3.6 frozenlist-1.3.3 huggingface-hub-0.14.1 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 xxhash-3.2.0 yarl-1.9.2\n",
|
86 |
+
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
|
87 |
+
"Collecting transformers\n",
|
88 |
+
" Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)\n",
|
89 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m46.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
90 |
+
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n",
|
91 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n",
|
92 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n",
|
93 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.11.0)\n",
|
94 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n",
|
95 |
+
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1\n",
|
96 |
+
" Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n",
|
97 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m94.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
98 |
+
"\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n",
|
99 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.1)\n",
|
100 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.14.1)\n",
|
101 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n",
|
102 |
+
"Requirement already satisfied: fsspec in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (2023.4.0)\n",
|
103 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n",
|
104 |
+
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2.0.12)\n",
|
105 |
+
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n",
|
106 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n",
|
107 |
+
"Installing collected packages: tokenizers, transformers\n",
|
108 |
+
"Successfully installed tokenizers-0.13.3 transformers-4.28.1\n"
|
109 |
+
]
|
110 |
+
}
|
111 |
+
],
|
112 |
+
"source": [
|
113 |
+
"!pip install datasets\n",
|
114 |
+
"!pip install transformers\n",
|
115 |
+
"import pandas as pd\n",
|
116 |
+
"from sklearn.model_selection import train_test_split\n",
|
117 |
+
"import numpy as np\n",
|
118 |
+
"import transformers\n",
|
119 |
+
"import torch\n",
|
120 |
+
"import csv\n",
|
121 |
+
"from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler\n",
|
122 |
+
"from transformers import DistilBertTokenizer, DistilBertModel, DistilBertForSequenceClassification, AdamW"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"source": [
|
128 |
+
"filename = \"/content/sample_data/train.csv\"\n",
|
129 |
+
"df = pd.read_csv(filename)\n",
|
130 |
+
"df.head()\n",
|
131 |
+
"df.drop(['id'], inplace=True, axis=1)\n",
|
132 |
+
"newdf = pd.DataFrame()\n",
|
133 |
+
"newdf['text'] = df['comment_text']\n",
|
134 |
+
"newdf['labels'] = df.iloc[:, 1:].values.tolist()\n",
|
135 |
+
"\n",
|
136 |
+
"newdf.head()"
|
137 |
+
],
|
138 |
+
"metadata": {
|
139 |
+
"colab": {
|
140 |
+
"base_uri": "https://localhost:8080/",
|
141 |
+
"height": 381
|
142 |
+
},
|
143 |
+
"id": "XQEDvn-7ksXU",
|
144 |
+
"outputId": "960bd74f-2533-4eab-9800-643823e14f2f"
|
145 |
+
},
|
146 |
+
"execution_count": null,
|
147 |
+
"outputs": [
|
148 |
+
{
|
149 |
+
"output_type": "error",
|
150 |
+
"ename": "FileNotFoundError",
|
151 |
+
"evalue": "ignored",
|
152 |
+
"traceback": [
|
153 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
154 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
155 |
+
"\u001b[0;32m<ipython-input-3-dadeecb77ee2>\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"/content/sample_data/train.csv\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'id'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mnewdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
156 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/util/_decorators.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnew_arg_name\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_arg_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 211\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 212\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
157 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/util/_decorators.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[0mstacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfind_stack_level\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 330\u001b[0m )\n\u001b[0;32m--> 331\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 332\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;31m# error: \"Callable[[VarArg(Any), KwArg(Any)], Any]\" has no\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
158 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, error_bad_lines, warn_bad_lines, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options)\u001b[0m\n\u001b[1;32m 948\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwds_defaults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 950\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 951\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 952\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
159 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 603\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 604\u001b[0m \u001b[0;31m# Create the parser.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 605\u001b[0;31m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextFileReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 606\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 607\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mchunksize\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
160 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 1440\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandles\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mIOHandles\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1442\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1443\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1444\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
161 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_make_engine\u001b[0;34m(self, f, engine)\u001b[0m\n\u001b[1;32m 1733\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m\"b\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m\"b\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1735\u001b[0;31m self.handles = get_handle(\n\u001b[0m\u001b[1;32m 1736\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1737\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
162 |
+
"\u001b[0;32m/usr/local/lib/python3.9/dist-packages/pandas/io/common.py\u001b[0m in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 854\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoding\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m\"b\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 855\u001b[0m \u001b[0;31m# Encoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 856\u001b[0;31m handle = open(\n\u001b[0m\u001b[1;32m 857\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
163 |
+
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/content/sample_data/train.csv'"
|
164 |
+
]
|
165 |
+
}
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "code",
|
170 |
+
"source": [
|
171 |
+
"epoch = 1\n",
|
172 |
+
"max_len = 128\n",
|
173 |
+
"batch_size = 5\n",
|
174 |
+
"\n",
|
175 |
+
"train_df, val_df = train_test_split(newdf, test_size=0.2, random_state=42)\n",
|
176 |
+
"\"\"\"\n",
|
177 |
+
"DistilBertTokenizer\n",
|
178 |
+
"torch.utils.data.Dataset\n",
|
179 |
+
"inputs = self.tokenizer.encode_plus\n",
|
180 |
+
"DataLoader\n",
|
181 |
+
"PreTrainedModel\n",
|
182 |
+
"DistilBertForSequenceClassification\n",
|
183 |
+
"DistilBertConfig\n",
|
184 |
+
"model = DistilBertClassifier2(config)\n",
|
185 |
+
"model.to(device)\n",
|
186 |
+
"torch.optim.Adam\n",
|
187 |
+
"tokenizer.encode_plus\n",
|
188 |
+
"tokenizer.save_pretrained(\"model\")\n",
|
189 |
+
"model.save_pretrained(\"model\")\n",
|
190 |
+
"\"\"\""
|
191 |
+
],
|
192 |
+
"metadata": {
|
193 |
+
"colab": {
|
194 |
+
"base_uri": "https://localhost:8080/",
|
195 |
+
"height": 235
|
196 |
+
},
|
197 |
+
"id": "DO8fKxgnwIPz",
|
198 |
+
"outputId": "d0f73814-62a0-4d74-9353-3d4ce90b6d1b"
|
199 |
+
},
|
200 |
+
"execution_count": null,
|
201 |
+
"outputs": [
|
202 |
+
{
|
203 |
+
"output_type": "error",
|
204 |
+
"ename": "NameError",
|
205 |
+
"evalue": "ignored",
|
206 |
+
"traceback": [
|
207 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
208 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
209 |
+
"\u001b[0;32m<ipython-input-4-232584db16d1>\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mtrain_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_test_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewdf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m42\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \"\"\"\n\u001b[1;32m 7\u001b[0m \u001b[0mDistilBertTokenizer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
210 |
+
"\u001b[0;31mNameError\u001b[0m: name 'newdf' is not defined"
|
211 |
+
]
|
212 |
+
}
|
213 |
+
]
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"cell_type": "code",
|
217 |
+
"source": [
|
218 |
+
"class DS(Dataset):\n",
|
219 |
+
" def __init__(self, dataframe, max_len):\n",
|
220 |
+
" self.data = dataframe\n",
|
221 |
+
" self.max_len = max_len\n",
|
222 |
+
" self.text = dataframe.text\n",
|
223 |
+
" self.targets = self.data.labels\n",
|
224 |
+
" self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
|
225 |
+
" \n",
|
226 |
+
" def __len__(self):\n",
|
227 |
+
" return len(self.data)\n",
|
228 |
+
" \n",
|
229 |
+
" def __getitem__(self, index):\n",
|
230 |
+
" text = str(self.text.iloc[index])\n",
|
231 |
+
" text = \" \".join(text.split())\n",
|
232 |
+
"\n",
|
233 |
+
" inputs = self.tokenizer.encode_plus(\n",
|
234 |
+
" text, None,\n",
|
235 |
+
" add_special_tokens=True,\n",
|
236 |
+
" max_length=self.max_len,\n",
|
237 |
+
" pad_to_max_length=True,\n",
|
238 |
+
" return_token_type_ids=True)\n",
|
239 |
+
" ids = inputs['input_ids']\n",
|
240 |
+
" mask = inputs['attention_mask']\n",
|
241 |
+
" token_type_ids = inputs[\"token_type_ids\"]\n",
|
242 |
+
" return {\n",
|
243 |
+
" 'ids': torch.tensor(ids, dtype=torch.long),\n",
|
244 |
+
" 'attention_mask': torch.tensor(mask, dtype=torch.long),\n",
|
245 |
+
" 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n",
|
246 |
+
" 'labels': torch.tensor(self.targets.iloc[index], dtype=torch.float)\n",
|
247 |
+
" }\n"
|
248 |
+
],
|
249 |
+
"metadata": {
|
250 |
+
"id": "i80qLafpzWDh"
|
251 |
+
},
|
252 |
+
"execution_count": null,
|
253 |
+
"outputs": []
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"source": [
|
258 |
+
"traindata = DS(train_df, max_len)\n",
|
259 |
+
"validdata = DS(val_df, max_len)\n",
|
260 |
+
"train_loader = DataLoader(traindata, batch_size=batch_size, shuffle=True)\n",
|
261 |
+
"tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n"
|
262 |
+
],
|
263 |
+
"metadata": {
|
264 |
+
"id": "EMScGH58Poaw",
|
265 |
+
"colab": {
|
266 |
+
"base_uri": "https://localhost:8080/",
|
267 |
+
"height": 201
|
268 |
+
},
|
269 |
+
"outputId": "de081257-fb6c-4c73-c54d-e88dd1e2603f"
|
270 |
+
},
|
271 |
+
"execution_count": null,
|
272 |
+
"outputs": [
|
273 |
+
{
|
274 |
+
"output_type": "error",
|
275 |
+
"ename": "NameError",
|
276 |
+
"evalue": "ignored",
|
277 |
+
"traceback": [
|
278 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
279 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
280 |
+
"\u001b[0;32m<ipython-input-6-84bc2329fae8>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtraindata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDS\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mvaliddata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDS\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtrain_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraindata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
281 |
+
"\u001b[0;31mNameError\u001b[0m: name 'train_df' is not defined"
|
282 |
+
]
|
283 |
+
}
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"metadata": {
|
289 |
+
"id": "fb9-Yr9YDZqo",
|
290 |
+
"colab": {
|
291 |
+
"base_uri": "https://localhost:8080/",
|
292 |
+
"height": 235
|
293 |
+
},
|
294 |
+
"outputId": "0664e5d0-55cb-4b58-e75a-b9acdab82e73"
|
295 |
+
},
|
296 |
+
"source": [
|
297 |
+
"device = torch.device('cuda')\n",
|
298 |
+
"\n",
|
299 |
+
"model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=6, problem_type=\"multi_label_classification\")\n",
|
300 |
+
"model.to(device)\n",
|
301 |
+
"model.train()\n",
|
302 |
+
"\n",
|
303 |
+
"optimizer = AdamW(model.parameters(), lr=5e-5)\n"
|
304 |
+
],
|
305 |
+
"execution_count": null,
|
306 |
+
"outputs": [
|
307 |
+
{
|
308 |
+
"output_type": "error",
|
309 |
+
"ename": "NameError",
|
310 |
+
"evalue": "ignored",
|
311 |
+
"traceback": [
|
312 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
313 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
314 |
+
"\u001b[0;32m<ipython-input-1-26b16f11a588>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDistilBertForSequenceClassification\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_pretrained\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'distilbert-base-uncased'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproblem_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"multi_label_classification\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
315 |
+
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
|
316 |
+
]
|
317 |
+
}
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"source": [
|
323 |
+
"for i in range(epoch):\n",
|
324 |
+
" for batch in train_loader:\n",
|
325 |
+
" optimizer.zero_grad()\n",
|
326 |
+
" ids = batch['input_ids'].to(device)\n",
|
327 |
+
" mask = batch['attention_mask'].to(device)\n",
|
328 |
+
" labels = batch['labels'].to(device)\n",
|
329 |
+
"\n",
|
330 |
+
" outputs = model(ids, attention_mask=mask, labels=labels)\n",
|
331 |
+
"\n",
|
332 |
+
" loss = outputs[0]\n",
|
333 |
+
" loss.backward()\n",
|
334 |
+
" optimizer.step()\n",
|
335 |
+
"model.eval()\n"
|
336 |
+
],
|
337 |
+
"metadata": {
|
338 |
+
"id": "mtMhE5_z8kw8"
|
339 |
+
},
|
340 |
+
"execution_count": null,
|
341 |
+
"outputs": []
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"source": [
|
346 |
+
"xtrain = [\"FUCK YOUR FILTHY MOTHER IN THE ASS, DRY!\"]\n",
|
347 |
+
"batch = tokenizer(xtrain, truncation=True, padding='max_length', return_tensors=\"pt\").to(device)\n",
|
348 |
+
"\n",
|
349 |
+
"with torch.no_grad():\n",
|
350 |
+
" outputs = model(**batch)\n",
|
351 |
+
" results = torch.sigmoid(outputs.logits)*100\n",
|
352 |
+
" print(results)\n",
|
353 |
+
"\n",
|
354 |
+
"model.save_pretrained(\"pretrained_model\")\n",
|
355 |
+
"tokenizer.save_pretrained(\"model_tokenizer\")"
|
356 |
+
],
|
357 |
+
"metadata": {
|
358 |
+
"id": "8T4UG8K8BvUn"
|
359 |
+
},
|
360 |
+
"execution_count": null,
|
361 |
+
"outputs": []
|
362 |
+
}
|
363 |
+
]
|
364 |
+
}
|