jbraha commited on
Commit
2c55221
·
1 Parent(s): d809919

'mint autosave'

Browse files
.ipynb_checkpoints/training-checkpoint.ipynb ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "215a1aae",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2023-04-23 12:34:45.188102: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
14
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
15
+ "2023-04-23 12:34:45.742757: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import torch\n",
21
+ "from torch.utils.data import Dataset\n",
22
+ "\n",
23
+ "import pandas as pd\n",
24
+ "# import numpy as np\n",
25
+ "\n",
26
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
27
+ "from transformers import Trainer, TrainingArguments"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 10,
33
+ "id": "9969c58c",
34
+ "metadata": {
35
+ "scrolled": false
36
+ },
37
+ "outputs": [
38
+ {
39
+ "name": "stderr",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "IOPub data rate exceeded.\n",
43
+ "The notebook server will temporarily stop sending output\n",
44
+ "to the client in order to avoid crashing it.\n",
45
+ "To change this limit, set the config variable\n",
46
+ "`--NotebookApp.iopub_data_rate_limit`.\n",
47
+ "\n",
48
+ "Current values:\n",
49
+ "NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
50
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
51
+ "\n",
52
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (631 > 512). Running this sequence through the model will result in indexing errors\n"
53
+ ]
54
+ },
55
+ {
56
+ "ename": "ValueError",
57
+ "evalue": "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).",
58
+ "output_type": "error",
59
+ "traceback": [
60
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
61
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
62
+ "\u001b[0;32m/tmp/ipykernel_325077/677523904.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtrain_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mtest_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTweetDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_encodings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
63
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2536\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_in_target_context_manager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2537\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_input_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2538\u001b[0;31m \u001b[0mencodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_one\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_pair\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext_pair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2539\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtext_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2540\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_target_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
64
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m_call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2595\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_is_valid_text_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2596\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 2597\u001b[0m \u001b[0;34m\"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2598\u001b[0m \u001b[0;34m\"or `List[List[str]]` (batch of pretokenized examples).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
65
+ "\u001b[0;31mValueError\u001b[0m: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
66
+ ]
67
+ }
68
+ ],
69
+ "source": [
70
+ "model_name = \"bert-base-uncased\"\n",
71
+ "\n",
72
+ "# dataset class that inherits from torch.utils.data.Dataset\n",
73
+ "class TweetDataset(Dataset):\n",
74
+ " def __init__(self, encodings, labels):\n",
75
+ " self.encodings = encodings\n",
76
+ " self.labels = labels\n",
77
+ " \n",
78
+ " def __getitem__(self, idx):\n",
79
+ " item = { key: torch.tensor(val[idx]) for key, val in self.encodings.items() }\n",
80
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
81
+ " return item\n",
82
+ " \n",
83
+ " def __len__(self):\n",
84
+ " return len(self.labels)\n",
85
+ " \n",
86
+ "\n",
87
+ "\n",
88
+ "train_data = pd.read_csv(\"data/train.csv\")\n",
89
+ "train_text = train_data[\"comment_text\"].values.tolist()\n",
90
+ "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
91
+ " \"obscene\", \"threat\", \n",
92
+ " \"insult\", \"identity_hate\"]].values.tolist()\n",
93
+ "\n",
94
+ "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"].values.tolist()\n",
95
+ "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
96
+ " \"toxic\", \"severe_toxic\", \n",
97
+ " \"obscene\", \"threat\", \n",
98
+ " \"insult\", \"identity_hate\"]].values.tolist()\n",
99
+ "\n",
100
+ "\n",
101
+ "# prepare tokenizer and dataset\n",
102
+ "\n",
103
+ "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
104
+ "\n",
105
+ "print(train_text)\n",
106
+ "\n",
107
+ "\n",
108
+ "train_encodings = tokenizer(train_text)\n",
109
+ "test_encodings = tokenizer(test_text)\n",
110
+ "\n",
111
+ "train_dataset = TweetDataset(train_encodings, train_labels)\n",
112
+ "test_dataset = TweetDataset(test_encodings, test_labels)\n",
113
+ "\n",
114
+ "\n",
115
+ "# training\n",
116
+ "\n",
117
+ "\n",
118
+ "training_args = TrainingArguments(\n",
119
+ " output_dir=\"results\",\n",
120
+ " num_train_epochs=2,\n",
121
+ " per_device_train_batch_size=16,\n",
122
+ " per_device_eval_batch_size=64,\n",
123
+ " warmup_steps=500,\n",
124
+ " learning_rate=5e-5,\n",
125
+ " weight_decay=0.01,\n",
126
+ " logging_dir=\"./logs\",\n",
127
+ " logging_steps=10\n",
128
+ " )\n",
129
+ "\n",
130
+ "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
131
+ "\n",
132
+ "\n",
133
+ "trainer = Trainer(\n",
134
+ " model=model, \n",
135
+ " args=training_args, \n",
136
+ " train_dataset=train_dataset, \n",
137
+ " eval_dataset=test_dataset)\n",
138
+ "\n",
139
+ "trainer.train()\n"
140
+ ]
141
+ }
142
+ ],
143
+ "metadata": {
144
+ "kernelspec": {
145
+ "display_name": "Python 3 (ipykernel)",
146
+ "language": "python",
147
+ "name": "python3"
148
+ },
149
+ "language_info": {
150
+ "codemirror_mode": {
151
+ "name": "ipython",
152
+ "version": 3
153
+ },
154
+ "file_extension": ".py",
155
+ "mimetype": "text/x-python",
156
+ "name": "python",
157
+ "nbconvert_exporter": "python",
158
+ "pygments_lexer": "ipython3",
159
+ "version": "3.10.6"
160
+ }
161
+ },
162
+ "nbformat": 4,
163
+ "nbformat_minor": 5
164
+ }
data/.~lock.test.csv# ADDED
@@ -0,0 +1 @@
 
 
1
+ ,joe,mint,23.04.2023 12:27,file:///home/joe/.config/libreoffice/4;
data/.~lock.test_labels.csv# ADDED
@@ -0,0 +1 @@
 
 
1
+ ,joe,mint,23.04.2023 11:48,file:///home/joe/.config/libreoffice/4;
data/.~lock.train.csv# ADDED
@@ -0,0 +1 @@
 
 
1
+ ,joe,mint,23.04.2023 11:51,file:///home/joe/.config/libreoffice/4;
training.ipynb ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "215a1aae",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "2023-04-23 12:34:45.188102: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
14
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
15
+ "2023-04-23 12:34:45.742757: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import torch\n",
21
+ "from torch.utils.data import Dataset\n",
22
+ "\n",
23
+ "import pandas as pd\n",
24
+ "# import numpy as np\n",
25
+ "\n",
26
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
27
+ "from transformers import Trainer, TrainingArguments"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 10,
33
+ "id": "9969c58c",
34
+ "metadata": {
35
+ "scrolled": false
36
+ },
37
+ "outputs": [
38
+ {
39
+ "name": "stderr",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "IOPub data rate exceeded.\n",
43
+ "The notebook server will temporarily stop sending output\n",
44
+ "to the client in order to avoid crashing it.\n",
45
+ "To change this limit, set the config variable\n",
46
+ "`--NotebookApp.iopub_data_rate_limit`.\n",
47
+ "\n",
48
+ "Current values:\n",
49
+ "NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
50
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
51
+ "\n",
52
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (631 > 512). Running this sequence through the model will result in indexing errors\n"
53
+ ]
54
+ },
55
+ {
56
+ "ename": "ValueError",
57
+ "evalue": "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).",
58
+ "output_type": "error",
59
+ "traceback": [
60
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
61
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
62
+ "\u001b[0;32m/tmp/ipykernel_325077/677523904.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mtrain_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mtest_encodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_text\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTweetDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_encodings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
63
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2536\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_in_target_context_manager\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2537\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_input_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2538\u001b[0;31m \u001b[0mencodings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_one\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_pair\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtext_pair\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2539\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtext_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2540\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_switch_to_target_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
64
+ "\u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/tokenization_utils_base.py\u001b[0m in \u001b[0;36m_call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2595\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_is_valid_text_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2596\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 2597\u001b[0m \u001b[0;34m\"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2598\u001b[0m \u001b[0;34m\"or `List[List[str]]` (batch of pretokenized examples).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
65
+ "\u001b[0;31mValueError\u001b[0m: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
66
+ ]
67
+ }
68
+ ],
69
+ "source": [
70
+ "model_name = \"bert-base-uncased\"\n",
71
+ "\n",
72
+ "# dataset class that inherits from torch.utils.data.Dataset\n",
73
+ "class TweetDataset(Dataset):\n",
74
+ " def __init__(self, encodings, labels):\n",
75
+ " self.encodings = encodings\n",
76
+ " self.labels = labels\n",
77
+ " \n",
78
+ " def __getitem__(self, idx):\n",
79
+ " item = { key: torch.tensor(val[idx]) for key, val in self.encodings.items() }\n",
80
+ " item['labels'] = torch.tensor(self.labels[idx])\n",
81
+ " return item\n",
82
+ " \n",
83
+ " def __len__(self):\n",
84
+ " return len(self.labels)\n",
85
+ " \n",
86
+ "\n",
87
+ "\n",
88
+ "train_data = pd.read_csv(\"data/train.csv\")\n",
89
+ "train_text = train_data[\"comment_text\"].values.tolist()\n",
90
+ "train_labels = train_data[[\"toxic\", \"severe_toxic\", \n",
91
+ " \"obscene\", \"threat\", \n",
92
+ " \"insult\", \"identity_hate\"]].values.tolist()\n",
93
+ "\n",
94
+ "test_text = pd.read_csv(\"data/test.csv\")[\"comment_text\"].values.tolist()\n",
95
+ "test_labels = pd.read_csv(\"data/test_labels.csv\")[[\n",
96
+ " \"toxic\", \"severe_toxic\", \n",
97
+ " \"obscene\", \"threat\", \n",
98
+ " \"insult\", \"identity_hate\"]].values.tolist()\n",
99
+ "\n",
100
+ "\n",
101
+ "# prepare tokenizer and dataset\n",
102
+ "\n",
103
+ "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
104
+ "\n",
105
+ "print(train_text)\n",
106
+ "\n",
107
+ "\n",
108
+ "train_encodings = tokenizer(train_text)\n",
109
+ "test_encodings = tokenizer(test_text)\n",
110
+ "\n",
111
+ "train_dataset = TweetDataset(train_encodings, train_labels)\n",
112
+ "test_dataset = TweetDataset(test_encodings, test_labels)\n",
113
+ "\n",
114
+ "\n",
115
+ "# training\n",
116
+ "\n",
117
+ "\n",
118
+ "training_args = TrainingArguments(\n",
119
+ " output_dir=\"results\",\n",
120
+ " num_train_epochs=2,\n",
121
+ " per_device_train_batch_size=16,\n",
122
+ " per_device_eval_batch_size=64,\n",
123
+ " warmup_steps=500,\n",
124
+ " learning_rate=5e-5,\n",
125
+ " weight_decay=0.01,\n",
126
+ " logging_dir=\"./logs\",\n",
127
+ " logging_steps=10\n",
128
+ " )\n",
129
+ "\n",
130
+ "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=6)\n",
131
+ "\n",
132
+ "\n",
133
+ "trainer = Trainer(\n",
134
+ " model=model, \n",
135
+ " args=training_args, \n",
136
+ " train_dataset=train_dataset, \n",
137
+ " eval_dataset=test_dataset)\n",
138
+ "\n",
139
+ "trainer.train()\n"
140
+ ]
141
+ }
142
+ ],
143
+ "metadata": {
144
+ "kernelspec": {
145
+ "display_name": "Python 3 (ipykernel)",
146
+ "language": "python",
147
+ "name": "python3"
148
+ },
149
+ "language_info": {
150
+ "codemirror_mode": {
151
+ "name": "ipython",
152
+ "version": 3
153
+ },
154
+ "file_extension": ".py",
155
+ "mimetype": "text/x-python",
156
+ "name": "python",
157
+ "nbconvert_exporter": "python",
158
+ "pygments_lexer": "ipython3",
159
+ "version": "3.10.6"
160
+ }
161
+ },
162
+ "nbformat": 4,
163
+ "nbformat_minor": 5
164
+ }