working trainer
- .ipynb_checkpoints/Copy_of_Copy_of_training-checkpoint.ipynb +0 -0
- Copy_of_Copy_of_training.ipynb +345 -0
- logs/1682300361.4426298/events.out.tfevents.1682300361.mint.371280.1 +0 -0
- logs/1682300884.6095285/events.out.tfevents.1682300884.mint.371280.3 +0 -0
- logs/1682300938.1223385/events.out.tfevents.1682300938.mint.371280.5 +0 -0
- logs/1682301013.2686887/events.out.tfevents.1682301013.mint.371280.7 +0 -0
- logs/events.out.tfevents.1682300361.mint.371280.0 +0 -0
- logs/events.out.tfevents.1682300884.mint.371280.2 +0 -0
- logs/events.out.tfevents.1682300938.mint.371280.4 +0 -0
- logs/events.out.tfevents.1682301013.mint.371280.6 +0 -0
- train.py +50 -55
- working_training.ipynb +601 -0
.ipynb_checkpoints/Copy_of_Copy_of_training-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
Copy_of_Copy_of_training.ipynb
ADDED
@@ -0,0 +1,345 @@
A new notebook with six code cells (Python 3 kernel, Colab metadata). Cell sources and key outputs:

Cell 1 — imports. stderr: the TensorFlow CPU feature-guard notice and "TF-TRT Warning: Could not find TensorRT".

    import torch
    from torch.utils.data import Dataset, DataLoader

    import pandas as pd

    from transformers import BertTokenizerFast, BertForSequenceClassification
    from transformers import Trainer, TrainingArguments

Cell 2 — model, training arguments, and a thin dataset over the raw comment strings. stderr: the expected "Some weights of the model checkpoint at bert-base-uncased were not used ..." / "You should probably TRAIN this model on a down-stream task ..." warning for a newly initialized classification head.

    model_name = "bert-base-uncased"
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)
    model = model.to("cuda:0")
    max_len = 200

    training_args = TrainingArguments(
        output_dir="results",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        learning_rate=5e-5,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10
        )

    # dataset class that inherits from torch.utils.data.Dataset

    class TokenizerDataset(Dataset):
        def __init__(self, strings):
            self.strings = strings

        def __getitem__(self, idx):
            return self.strings[idx]

        def __len__(self):
            return len(self.strings)

Cell 3 — load the toxic-comment CSVs and convert to plain lists. stdout: a preview of data/train.csv with columns id, comment_text, toxic, severe_toxic, obscene, threat, insult, identity_hate — [159571 rows x 8 columns].

    train_data = pd.read_csv("data/train.csv")
    print(train_data)
    train_text = train_data["comment_text"]
    train_labels = train_data[["toxic", "severe_toxic", 
                               "obscene", "threat", 
                               "insult", "identity_hate"]]

    test_text = pd.read_csv("data/test.csv")["comment_text"]
    test_labels = pd.read_csv("data/test_labels.csv")[[
        "toxic", "severe_toxic", 
        "obscene", "threat", 
        "insult", "identity_hate"]]

    # data preprocessing

    train_text = train_text.values.tolist()
    train_labels = train_labels.values.tolist()
    test_text = test_text.values.tolist()
    test_labels = test_labels.values.tolist()

Cell 4 — datasets that tokenize per example, plus whole-corpus encodings on the GPU. TweetDataset2 is a verbatim copy of TweetDataset.

    # prepare tokenizer and dataset

    class TweetDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
            self.tok = tokenizer

        def __getitem__(self, idx):
    #         print(idx)
            print(len(self.labels))
            encoding = self.tok(self.encodings.strings[idx], truncation=True, padding="max_length", max_length=max_len).to("cuda:0")
            print(encoding.items())
            item = { key: torch.tensor(val) for key, val in encoding.items() }
            item['labels'] = torch.tensor(self.labels[idx])
    #         print(item)
            return item

        def __len__(self):
            return len(self.labels)

    # no tokenizer
    class TweetDataset2(Dataset):
        # (body identical to TweetDataset above)
        ...

    train_strings = TokenizerDataset(train_text)
    test_strings = TokenizerDataset(test_text)

    train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)
    test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)

    train_encodings = tokenizer.batch_encode_plus(train_text, \
                            max_length=200, pad_to_max_length=True, \
                            truncation=True, return_token_type_ids=False, return_tensors='pt' \
                            ).to("cuda:0")
    test_encodings = tokenizer.batch_encode_plus(test_text, \
                            max_length=200, pad_to_max_length=True, \
                            truncation=True, return_token_type_ids=False, return_tensors='pt' \
                            ).to("cuda:0")

    # train_encodings = tokenizer(train_text, truncation=True, padding=True)
    # test_encodings = tokenizer(test_text, truncation=True, padding=True)

Cell 5 — a dataset over the pre-computed encodings, the multi-label trainer, and its instantiation. stdout: 159571 and 159571.

    # no tokenizer
    class TweetDataset3(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
            self.tok = tokenizer

        def __getitem__(self, idx):
            print(idx)
            item = { key: torch.tensor(val) for key, val in self.encodings.items() }
            item['labels'] = torch.tensor(self.labels[idx])
    #         print(item)
            return item

        def __len__(self):
            return len(self.labels)

    train_dataset = TweetDataset3(train_encodings, train_labels)
    test_dataset = TweetDataset3(test_encodings, test_labels)

    print(len(train_dataset.labels))
    print(len(train_strings))

    class MultilabelTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.model.config.num_labels), 
                            labels.float().view(-1, self.model.config.num_labels))
            return (loss, outputs) if return_outputs else loss

    # training
    trainer = MultilabelTrainer(
        model=model, 
        args=training_args, 
        train_dataset=train_dataset, 
        eval_dataset=test_dataset
        )

Cell 6 — start training (no output recorded in this copy of the notebook):

    trainer.train()
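Note that TweetDataset3 above builds every item from the entire pre-tokenized batch: torch.tensor(val) is applied to the whole (159571, 200) tensors rather than to the idx-th row. A minimal sketch of per-example slicing, assuming encodings is the BatchEncoding returned by batch_encode_plus(..., return_tensors='pt'); the class name is hypothetical and not part of this commit:

    import torch
    from torch.utils.data import Dataset

    class PreTokenizedDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings  # dict-like: input_ids, attention_mask tensors
            self.labels = labels        # list of per-example multi-hot label vectors

        def __getitem__(self, idx):
            # take only the idx-th row of each encoded tensor
            item = {key: val[idx] for key, val in self.encodings.items()}
            item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
            return item

        def __len__(self):
            return len(self.labels)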
logs/1682300361.4426298/events.out.tfevents.1682300361.mint.371280.1
ADDED
Binary file (5.8 kB).

logs/1682300884.6095285/events.out.tfevents.1682300884.mint.371280.3
ADDED
Binary file (5.8 kB).

logs/1682300938.1223385/events.out.tfevents.1682300938.mint.371280.5
ADDED
Binary file (5.8 kB).

logs/1682301013.2686887/events.out.tfevents.1682301013.mint.371280.7
ADDED
Binary file (5.8 kB).

logs/events.out.tfevents.1682300361.mint.371280.0
ADDED
Binary file (4.19 kB).

logs/events.out.tfevents.1682300884.mint.371280.2
ADDED
Binary file (4.19 kB).

logs/events.out.tfevents.1682300938.mint.371280.4
ADDED
Binary file (4.19 kB).

logs/events.out.tfevents.1682301013.mint.371280.6
ADDED
Binary file (4.19 kB).
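These are TensorBoard event files written by the Trainer's logging (logging_dir="./logs", logging_steps=10). Assuming TensorBoard is installed, they can be inspected by pointing it at that directory, e.g. tensorboard --logdir logs.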
train.py
CHANGED
@@ -6,11 +6,11 @@ import pandas as pd
 from transformers import BertTokenizerFast, BertForSequenceClassification
 from transformers import Trainer, TrainingArguments
 
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 model_name = "bert-base-uncased"
 tokenizer = BertTokenizerFast.from_pretrained(model_name)
-model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)
+model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)
 max_len = 200
 
 training_args = TrainingArguments(
@@ -26,20 +26,7 @@ training_args = TrainingArguments(
     )
 
 # dataset class that inherits from torch.utils.data.Dataset
-
-    def __init__(self, encodings, labels):
-        self.encodings = encodings
-        self.labels = labels
-        self.tok = tokenizer
-
-    def __getitem__(self, idx):
-#         encoding = self.tok(self.encodings[idx], truncation=True, padding="max_length", max_length=max_len)
-        item = { key: torch.tensor(val[idx]) for key, val in self.encoding.items() }
-        item['labels'] = torch.tensor(self.labels[idx])
-        return item
-
-    def __len__(self):
-        return len(self.labels)
+
 
 class TokenizerDataset(Dataset):
     def __init__(self, strings):
@@ -52,10 +39,8 @@ class TokenizerDataset(Dataset):
         return len(self.strings)
 
 
-
-
-
 train_data = pd.read_csv("data/train.csv")
+print(train_data)
 train_text = train_data["comment_text"]
 train_labels = train_data[["toxic", "severe_toxic", 
                            "obscene", "threat", 
@@ -77,9 +62,31 @@ test_text = test_text.values.tolist()
 test_labels = test_labels.values.tolist()
 
 
-# prepare tokenizer and dataset
+# prepare tokenizer and dataset
+
+class TweetDataset(Dataset):
+    def __init__(self, encodings, labels):
+        self.encodings = encodings
+        self.labels = labels
+        self.tok = tokenizer
+
+    def __getitem__(self, idx):
+        print(idx)
+#         print(len(self.labels))
+        encoding = self.tok(self.encodings.strings[idx], truncation=True, 
+                            padding="max_length", max_length=max_len)
+#         print(encoding.items())
+        item = { key: torch.tensor(val) for key, val in encoding.items() }
+        item['labels'] = torch.tensor(self.labels[idx])
+#         print(item)
+        return item
+
+    def __len__(self):
+        return len(self.labels)
+
+
 
 train_strings = TokenizerDataset(train_text)
 test_strings = TokenizerDataset(test_text)
@@ -99,45 +106,33 @@ test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)
 #                           truncation=True, return_token_type_ids=False \
 #                           )
 
-
-
-
-
-f = open("traintokens.txt", 'a')
-f.write(train_encodings)
-f.write('\n\n\n\n\n')
-f.close()
-
-g = open("testtokens.txt", 'a')
-g.write(test_encodings)
-g.write('\n\n\n\n\n')
-
-g.close()
-
-
-
-# train_dataset = TweetDataset(train_encodings, train_labels)
-# test_dataset = TweetDataset(test_encodings, test_labels)
-
-
-
-
-
-# # training
-# trainer = Trainer(
-#     model=model, 
-#     args=training_args, 
-#     train_dataset=train_dataset, 
-#     eval_dataset=test_dataset
-#     )
-
-
-# trainer.train()
-
+# train_encodings = tokenizer(train_text, truncation=True, padding=True)
+# test_encodings = tokenizer(test_text, truncation=True, padding=True)
 
+train_dataset = TweetDataset(train_strings, train_labels)
+test_dataset = TweetDataset(test_strings, test_labels)
 
+print(len(train_dataset.labels))
+print(len(train_strings))
 
 
+class MultilabelTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        labels = inputs.pop("labels")
+        outputs = model(**inputs)
+        logits = outputs.logits
+        loss_fct = torch.nn.BCEWithLogitsLoss()
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), 
+                        labels.float().view(-1, self.model.config.num_labels))
+        return (loss, outputs) if return_outputs else loss
 
 
+# training
+trainer = MultilabelTrainer(
+    model=model, 
+    args=training_args, 
+    train_dataset=train_dataset, 
+    eval_dataset=test_dataset
+    )
 
+trainer.train()
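The MultilabelTrainer introduced above swaps the Trainer's default loss for BCEWithLogitsLoss, so each of the six toxicity labels is scored as an independent binary decision rather than one softmax over classes. A minimal, self-contained illustration of that behaviour (the logits and labels below are made up for the example, not taken from this run):

    import torch

    # Two examples, six labels each; logits straight from the classifier head.
    logits = torch.tensor([[ 2.0, -1.0,  0.5, -3.0,  1.0, -2.0],
                           [-1.5, -2.0, -0.5, -4.0, -1.0, -3.0]])
    labels = torch.tensor([[1., 0., 1., 0., 1., 0.],
                           [0., 0., 0., 0., 0., 0.]])

    loss_fct = torch.nn.BCEWithLogitsLoss()
    loss = loss_fct(logits, labels)  # mean over all 12 label decisions

    # Equivalent by hand: sigmoid, then element-wise binary cross-entropy.
    probs = torch.sigmoid(logits)
    manual = -(labels * probs.log() + (1 - labels) * (1 - probs).log()).mean()
    print(loss.item(), manual.item())  # the two values agree up to rounding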
working_training.ipynb
ADDED
@@ -0,0 +1,601 @@
A new notebook with seven code cells (Colab GPU runtime, T4; Python 3 kernel). Cell sources and key outputs:

Cell 1 — imports (the torch_xla lines are left commented out):

    import torch
    from torch.utils.data import Dataset, DataLoader

    # import torch_xla
    # import torch_xla.core.xla_model as xm

    import pandas as pd

    from transformers import BertTokenizerFast, BertForSequenceClassification
    from transformers import Trainer, TrainingArguments

Cell 2 — device, model, training arguments, and the raw-string dataset. stderr: the expected "Some weights of the model checkpoint at bert-base-uncased were not used ..." / "You should probably TRAIN this model on a down-stream task ..." warning for the newly initialized classifier head.

    device = "cuda:0"

    model_name = "bert-base-uncased"
    tokenizer = BertTokenizerFast.from_pretrained(model_name)
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6).to(device)
    max_len = 200

    training_args = TrainingArguments(
        output_dir="results",
        num_train_epochs=1,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        learning_rate=5e-5,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10
        )

    # dataset class that inherits from torch.utils.data.Dataset

    class TokenizerDataset(Dataset):
        def __init__(self, strings):
            self.strings = strings

        def __getitem__(self, idx):
            return self.strings[idx]

        def __len__(self):
            return len(self.strings)

Cell 3 — load the CSVs and convert to lists; the source is identical to the corresponding cell in Copy_of_Copy_of_training.ipynb. stdout: the same preview of data/train.csv, [159571 rows x 8 columns].

Cell 4 — a dataset that tokenizes one comment at a time in __getitem__; the earlier whole-corpus batch_encode_plus calls are left commented out.

    # prepare tokenizer and dataset

    class TweetDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
            self.tok = tokenizer

        def __getitem__(self, idx):
            # print(idx)
            # print(len(self.labels))
            encoding = self.tok(self.encodings.strings[idx], truncation=True, 
                                padding="max_length", max_length=max_len)
            # print(encoding.items())
            item = { key: torch.tensor(val) for key, val in encoding.items() }
            item['labels'] = torch.tensor(self.labels[idx])
            # print(item)
            return item

        def __len__(self):
            return len(self.labels)


    train_strings = TokenizerDataset(train_text)
    test_strings = TokenizerDataset(test_text)

    train_dataloader = DataLoader(train_strings, batch_size=16, shuffle=True)
    test_dataloader = DataLoader(test_strings, batch_size=16, shuffle=True)

    # train_encodings = tokenizer.batch_encode_plus(train_text, \
    #                         max_length=200, pad_to_max_length=True, \
    #                         truncation=True, return_token_type_ids=False)
    # #                         return_tensors='pt')
    # test_encodings = tokenizer.batch_encode_plus(test_text, \
    #                         max_length=200, pad_to_max_length=True, \
    #                         truncation=True, return_token_type_ids=False)
    # #                         return_tensors='pt')

    # train_encodings = tokenizer(train_text, truncation=True, padding=True)
    # test_encodings = tokenizer(test_text, truncation=True, padding=True)

Cell 5 — datasets, the multi-label trainer, and its instantiation. stdout: 159571 and 159571.

    train_dataset = TweetDataset(train_strings, train_labels)
    test_dataset = TweetDataset(test_strings, test_labels)

    print(len(train_dataset.labels))
    print(len(train_strings))


    class MultilabelTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.model.config.num_labels), 
                            labels.float().view(-1, self.model.config.num_labels))
            return (loss, outputs) if return_outputs else loss


    # training
    trainer = MultilabelTrainer(
        model=model, 
        args=training_args, 
        train_dataset=train_dataset, 
        eval_dataset=test_dataset
        )

Cell 6 — run training:

    trainer.train()

Outputs: a FutureWarning that transformers' AdamW implementation is deprecated in favour of torch.optim.AdamW, then the Trainer progress widget at step 582 of 9974 (05:37 elapsed, about 1.72 it/s, epoch 0.06 of 1), with training loss logged every 10 steps:

     10 0.6958    20 0.6742    30 0.6319    40 0.5706    50 0.5411    60 0.5003
     70 0.4408    80 0.4054    90 0.3362   100 0.2850   110 0.2324   120 0.2395
    130 0.1973   140 0.1967   150 0.1439   160 0.1537   170 0.0982   180 0.1297
    190 0.0945   200 0.1044   210 0.1190   220 0.0817   230 0.0818   240 0.0797
    250 0.0778   260 0.0932   270 0.0664   280 0.0640   290 0.0740   300 0.0842
    310 0.0643   320 0.0821   330 0.0579   340 0.0650   350 0.0729   360 0.0645
    370 0.0643   380 0.0719   390 0.0446   400 0.0593   410 0.0630   420 0.0824
    430 0.0701   440 0.0427   450 0.0895   460 0.0614   470 0.0973   480 0.0627
    490 0.0678   500 0.0833   510 0.0835   520 0.0533   530 0.0454   540 0.0523
    550 0.0753   560 0.0690   570 0.0848   580 0.0288

The run was cut short by a KeyboardInterrupt raised inside trainer.train() (training_step -> loss.backward() -> torch.autograd.backward).

Cell 7 — check the GPU (no output recorded):

    !nvidia-smi