pranay-j commited on
Commit
a6e171e
·
verified ·
1 Parent(s): 5c85b70

Training in progress, epoch 0

Browse files
.ipynb_checkpoints/finetuning_text_classification-checkpoint.ipynb ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d090c366-23e5-4221-a868-f290eefcedc2",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "from datasets import load_dataset\n",
20
+ "\n",
21
+ "dataset = load_dataset(\"google/boolq\")"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "id": "a6bad310-9514-4468-bdca-673b30dfd473",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "from transformers import AutoTokenizer\n",
32
+ "tokenizer=AutoTokenizer.from_pretrained(\"bert-base-uncased\")"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "013559ce-c991-4836-922c-5f9201265c66",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "dataset"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "38aac997-3d15-4e61-b80c-c1a4fff0b525",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "dataset[\"train\"][0]"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "f4d214cd-2fef-4778-bc3a-cb4e1c907515",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "def encode_question_context_pairs(example):\n",
63
+ " text=f'{example[\"question\"]} [SEP] {example[\"passage\"]}'\n",
64
+ " label= 0 if not example[\"answer\"] else 1\n",
65
+ " inputs=tokenizer(text,truncation=True)\n",
66
+ " inputs[\"labels\"]=[float(label)]\n",
67
+ " return inputs"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "id": "6fa2aa41-6286-4a69-ba23-90482d98f494",
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "train_dataset=dataset[\"train\"].map(encode_question_context_pairs,remove_columns=dataset[\"train\"].column_names)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "309bee55-b698-4c66-990d-beb00ac52746",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "validation_dataset=dataset[\"validation\"].map(encode_question_context_pairs,remove_columns=dataset[\"train\"].column_names)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "bf95690a-4ed4-4635-9b39-12bc4b486b5f",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# train_dataset['labels']"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "id": "00c07517-6976-4553-8188-2b7f4078adf3",
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": []
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "1371cc4a-3f0e-4e84-939b-218b570c0b6b",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": []
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "id": "85c9ccea-f788-4025-b185-c32c6fa51c46",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "# tokenizer(\"question\",\"answer\",max_length=512,padding=\"max_length\",truncation=\"only_second\",)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "30a82635-f956-404d-a95e-db753f7e07b7",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "from transformers import DataCollatorWithPadding\n",
134
+ "\n",
135
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "22d43e81-1739-443f-95fb-ee98b10a3a0b",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "import evaluate\n",
146
+ "\n",
147
+ "accuracy = evaluate.load(\"accuracy\")"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "id": "23fa9362-aa3d-4155-85a5-6caa6635c9f8",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "import numpy as np\n",
158
+ "\n",
159
+ "\n",
160
+ "def compute_metrics(eval_pred):\n",
161
+ " predictions, labels = eval_pred\n",
162
+ " predictions = np.where(predictions<0.5,0,1)\n",
163
+ " return accuracy.compute(predictions=predictions, references=labels)"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "e476c76f-21b6-4844-a6a5-29f18b4f6099",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
174
+ "\n",
175
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
176
+ " \"bert-base-uncased\", num_labels=1,\n",
177
+ ")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "5a359a0d-7563-4f4e-b4d4-03e6c601fc2f",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "training_args = TrainingArguments(\n",
188
+ " output_dir=\"./\",\n",
189
+ " learning_rate=2e-5,\n",
190
+ " per_device_train_batch_size=16,\n",
191
+ " per_device_eval_batch_size=16,\n",
192
+ " num_train_epochs=4,\n",
193
+ " weight_decay=0.01,\n",
194
+ " evaluation_strategy=\"epoch\",\n",
195
+ " save_strategy=\"epoch\",\n",
196
+ " load_best_model_at_end=True,\n",
197
+ " gradient_accumulation_steps=4,\n",
198
+ " logging_steps=50,\n",
199
+ " seed=42,\n",
200
+ " adam_beta1= 0.9,\n",
201
+ " adam_beta2= 0.999,\n",
202
+ " adam_epsilon= 1e-08,\n",
203
+ " report_to=\"tensorboard\",\n",
204
+ " push_to_hub=True,\n",
205
+ ")\n",
206
+ "\n",
207
+ "trainer = Trainer(\n",
208
+ " model=model,\n",
209
+ " args=training_args,\n",
210
+ " train_dataset=train_dataset,\n",
211
+ " eval_dataset=validation_dataset,\n",
212
+ " tokenizer=tokenizer,\n",
213
+ " data_collator=data_collator,\n",
214
+ " compute_metrics=compute_metrics,\n",
215
+ ")\n",
216
+ "\n",
217
+ "# trainer.train()"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "id": "0bc0fca5-d298-40d3-a80b-035a05fe6e1f",
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "model.save_pretrained(training_args.output_dir)\n",
228
+ "tokenizer.save_pretrained(training_args.output_dir)"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "id": "c96926e2-04c1-4e33-b83f-dc2b9c4d5b08",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "trainer.train()"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "75e96eb2-0d8e-4e5f-8844-6abce16bd1cb",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "kwargs = {\n",
249
+ " \"dataset_tags\": \"google/boolq\",\n",
250
+ " \"dataset\": \"boolq\", # a 'pretty' name for the training dataset\n",
251
+ " \"language\": \"en\",\n",
252
+ " \"model_name\": \"Bert Base Uncased Boolean Question Answer model\", # a 'pretty' name for your model\n",
253
+ " \"finetuned_from\": \"bert-base-uncased\",\n",
254
+ " \"tasks\": \"text-classification\",\n",
255
+ "}"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "ba5e73bd-d154-43ce-a869-f0f57045a386",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "trainer.push_to_hub(**kwargs)"
266
+ ]
267
+ }
268
+ ],
269
+ "metadata": {
270
+ "kernelspec": {
271
+ "display_name": "Python 3 (ipykernel)",
272
+ "language": "python",
273
+ "name": "python3"
274
+ },
275
+ "language_info": {
276
+ "codemirror_mode": {
277
+ "name": "ipython",
278
+ "version": 3
279
+ },
280
+ "file_extension": ".py",
281
+ "mimetype": "text/x-python",
282
+ "name": "python",
283
+ "nbconvert_exporter": "python",
284
+ "pygments_lexer": "ipython3",
285
+ "version": "3.10.12"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 5
290
+ }
config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "bert-base-uncased",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "id2label": {
13
+ "0": "LABEL_0"
14
+ },
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "label2id": {
18
+ "LABEL_0": 0
19
+ },
20
+ "layer_norm_eps": 1e-12,
21
+ "max_position_embeddings": 512,
22
+ "model_type": "bert",
23
+ "num_attention_heads": 12,
24
+ "num_hidden_layers": 12,
25
+ "pad_token_id": 0,
26
+ "position_embedding_type": "absolute",
27
+ "problem_type": "regression",
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.40.0",
30
+ "type_vocab_size": 2,
31
+ "use_cache": true,
32
+ "vocab_size": 30522
33
+ }
finetuning_text_classification.ipynb ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d090c366-23e5-4221-a868-f290eefcedc2",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "from datasets import load_dataset\n",
20
+ "\n",
21
+ "dataset = load_dataset(\"google/boolq\")"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "id": "a6bad310-9514-4468-bdca-673b30dfd473",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "from transformers import AutoTokenizer\n",
32
+ "tokenizer=AutoTokenizer.from_pretrained(\"bert-base-uncased\")"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 3,
38
+ "id": "013559ce-c991-4836-922c-5f9201265c66",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "data": {
43
+ "text/plain": [
44
+ "DatasetDict({\n",
45
+ " train: Dataset({\n",
46
+ " features: ['question', 'answer', 'passage'],\n",
47
+ " num_rows: 9427\n",
48
+ " })\n",
49
+ " validation: Dataset({\n",
50
+ " features: ['question', 'answer', 'passage'],\n",
51
+ " num_rows: 3270\n",
52
+ " })\n",
53
+ "})"
54
+ ]
55
+ },
56
+ "execution_count": 3,
57
+ "metadata": {},
58
+ "output_type": "execute_result"
59
+ }
60
+ ],
61
+ "source": [
62
+ "dataset"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 4,
68
+ "id": "38aac997-3d15-4e61-b80c-c1a4fff0b525",
69
+ "metadata": {},
70
+ "outputs": [
71
+ {
72
+ "data": {
73
+ "text/plain": [
74
+ "{'question': 'do iran and afghanistan speak the same language',\n",
75
+ " 'answer': True,\n",
76
+ " 'passage': 'Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.'}"
77
+ ]
78
+ },
79
+ "execution_count": 4,
80
+ "metadata": {},
81
+ "output_type": "execute_result"
82
+ }
83
+ ],
84
+ "source": [
85
+ "dataset[\"train\"][0]"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 5,
91
+ "id": "f4d214cd-2fef-4778-bc3a-cb4e1c907515",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "def encode_question_context_pairs(example):\n",
96
+ " text=f'{example[\"question\"]} [SEP] {example[\"passage\"]}'\n",
97
+ " label= 0 if not example[\"answer\"] else 1\n",
98
+ " inputs=tokenizer(text,truncation=True)\n",
99
+ " inputs[\"labels\"]=[float(label)]\n",
100
+ " return inputs"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 6,
106
+ "id": "6fa2aa41-6286-4a69-ba23-90482d98f494",
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "train_dataset=dataset[\"train\"].map(encode_question_context_pairs,remove_columns=dataset[\"train\"].column_names)"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 7,
116
+ "id": "309bee55-b698-4c66-990d-beb00ac52746",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "validation_dataset=dataset[\"validation\"].map(encode_question_context_pairs,remove_columns=dataset[\"train\"].column_names)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 8,
126
+ "id": "bf95690a-4ed4-4635-9b39-12bc4b486b5f",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "# train_dataset['labels']"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "id": "00c07517-6976-4553-8188-2b7f4078adf3",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": []
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "1371cc4a-3f0e-4e84-939b-218b570c0b6b",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": []
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": 9,
152
+ "id": "85c9ccea-f788-4025-b185-c32c6fa51c46",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "# tokenizer(\"question\",\"answer\",max_length=512,padding=\"max_length\",truncation=\"only_second\",)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 10,
162
+ "id": "30a82635-f956-404d-a95e-db753f7e07b7",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "from transformers import DataCollatorWithPadding\n",
167
+ "\n",
168
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 11,
174
+ "id": "22d43e81-1739-443f-95fb-ee98b10a3a0b",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "import evaluate\n",
179
+ "\n",
180
+ "accuracy = evaluate.load(\"accuracy\")"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 12,
186
+ "id": "23fa9362-aa3d-4155-85a5-6caa6635c9f8",
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "import numpy as np\n",
191
+ "\n",
192
+ "\n",
193
+ "def compute_metrics(eval_pred):\n",
194
+ " predictions, labels = eval_pred\n",
195
+ " predictions = np.where(predictions<0.5,0,1)\n",
196
+ " return accuracy.compute(predictions=predictions, references=labels)"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 13,
202
+ "id": "e476c76f-21b6-4844-a6a5-29f18b4f6099",
203
+ "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "name": "stderr",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
210
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
216
+ "\n",
217
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
218
+ " \"bert-base-uncased\", num_labels=1,\n",
219
+ ")"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 14,
225
+ "id": "5a359a0d-7563-4f4e-b4d4-03e6c601fc2f",
226
+ "metadata": {},
227
+ "outputs": [],
228
+ "source": [
229
+ "training_args = TrainingArguments(\n",
230
+ " output_dir=\"./\",\n",
231
+ " learning_rate=2e-5,\n",
232
+ " per_device_train_batch_size=16,\n",
233
+ " per_device_eval_batch_size=16,\n",
234
+ " num_train_epochs=4,\n",
235
+ " weight_decay=0.01,\n",
236
+ " evaluation_strategy=\"epoch\",\n",
237
+ " save_strategy=\"epoch\",\n",
238
+ " load_best_model_at_end=True,\n",
239
+ " gradient_accumulation_steps=4,\n",
240
+ " logging_steps=50,\n",
241
+ " seed=42,\n",
242
+ " adam_beta1= 0.9,\n",
243
+ " adam_beta2= 0.999,\n",
244
+ " adam_epsilon= 1e-08,\n",
245
+ " report_to=\"tensorboard\",\n",
246
+ " push_to_hub=True,\n",
247
+ ")\n",
248
+ "\n",
249
+ "trainer = Trainer(\n",
250
+ " model=model,\n",
251
+ " args=training_args,\n",
252
+ " train_dataset=train_dataset,\n",
253
+ " eval_dataset=validation_dataset,\n",
254
+ " tokenizer=tokenizer,\n",
255
+ " data_collator=data_collator,\n",
256
+ " compute_metrics=compute_metrics,\n",
257
+ ")\n",
258
+ "\n",
259
+ "# trainer.train()"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 15,
265
+ "id": "0bc0fca5-d298-40d3-a80b-035a05fe6e1f",
266
+ "metadata": {},
267
+ "outputs": [
268
+ {
269
+ "data": {
270
+ "text/plain": [
271
+ "('./tokenizer_config.json',\n",
272
+ " './special_tokens_map.json',\n",
273
+ " './vocab.txt',\n",
274
+ " './added_tokens.json',\n",
275
+ " './tokenizer.json')"
276
+ ]
277
+ },
278
+ "execution_count": 15,
279
+ "metadata": {},
280
+ "output_type": "execute_result"
281
+ }
282
+ ],
283
+ "source": [
284
+ "model.save_pretrained(training_args.output_dir)\n",
285
+ "tokenizer.save_pretrained(training_args.output_dir)"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": null,
291
+ "id": "c96926e2-04c1-4e33-b83f-dc2b9c4d5b08",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "\n",
298
+ " <div>\n",
299
+ " \n",
300
+ " <progress value='148' max='588' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
301
+ " [148/588 07:00 < 21:07, 0.35 it/s, Epoch 1.00/4]\n",
302
+ " </div>\n",
303
+ " <table border=\"1\" class=\"dataframe\">\n",
304
+ " <thead>\n",
305
+ " <tr style=\"text-align: left;\">\n",
306
+ " <th>Epoch</th>\n",
307
+ " <th>Training Loss</th>\n",
308
+ " <th>Validation Loss</th>\n",
309
+ " </tr>\n",
310
+ " </thead>\n",
311
+ " <tbody>\n",
312
+ " </tbody>\n",
313
+ "</table><p>\n",
314
+ " <div>\n",
315
+ " \n",
316
+ " <progress value='102' max='205' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
317
+ " [102/205 00:26 < 00:27, 3.76 it/s]\n",
318
+ " </div>\n",
319
+ " "
320
+ ],
321
+ "text/plain": [
322
+ "<IPython.core.display.HTML object>"
323
+ ]
324
+ },
325
+ "metadata": {},
326
+ "output_type": "display_data"
327
+ }
328
+ ],
329
+ "source": [
330
+ "trainer.train()"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "id": "75e96eb2-0d8e-4e5f-8844-6abce16bd1cb",
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "kwargs = {\n",
341
+ " \"dataset_tags\": \"google/boolq\",\n",
342
+ " \"dataset\": \"boolq\", # a 'pretty' name for the training dataset\n",
343
+ " \"language\": \"en\",\n",
344
+ " \"model_name\": \"Bert Base Uncased Boolean Question Answer model\", # a 'pretty' name for your model\n",
345
+ " \"finetuned_from\": \"bert-base-uncased\",\n",
346
+ " \"tasks\": \"text-classification\",\n",
347
+ "}"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "id": "ba5e73bd-d154-43ce-a869-f0f57045a386",
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "trainer.push_to_hub(**kwargs)"
358
+ ]
359
+ }
360
+ ],
361
+ "metadata": {
362
+ "kernelspec": {
363
+ "display_name": "Python 3 (ipykernel)",
364
+ "language": "python",
365
+ "name": "python3"
366
+ },
367
+ "language_info": {
368
+ "codemirror_mode": {
369
+ "name": "ipython",
370
+ "version": 3
371
+ },
372
+ "file_extension": ".py",
373
+ "mimetype": "text/x-python",
374
+ "name": "python",
375
+ "nbconvert_exporter": "python",
376
+ "pygments_lexer": "ipython3",
377
+ "version": "3.10.12"
378
+ }
379
+ },
380
+ "nbformat": 4,
381
+ "nbformat_minor": 5
382
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf66bfdcb9c367e6e70099af4ba926f4e7636c1562231ef2f272ec545f4ebc75
3
+ size 437955572
runs/Apr20_13-50-06_386b24d31d4c/events.out.tfevents.1713621007.386b24d31d4c ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c9560a1d93a3251de9c8ec432f164a5da072e66ab34c30898a5b5dfe5f7f181
3
+ size 4693
runs/Apr20_13-51-30_386b24d31d4c/events.out.tfevents.1713621092.386b24d31d4c ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ecca68f93aaf9ce7b644564aec6d87f2caec4cefe1abc0129e39ec387ebe0c5
3
+ size 4693
runs/Apr20_13-54-34_386b24d31d4c/events.out.tfevents.1713621277.386b24d31d4c ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:381504dba9f24f32546e28f5aeacbe7e397bbbb4b156fb4f62aa4ef5f0457e2b
3
+ size 5430
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "mask_token": "[MASK]",
48
+ "model_max_length": 512,
49
+ "pad_token": "[PAD]",
50
+ "sep_token": "[SEP]",
51
+ "strip_accents": null,
52
+ "tokenize_chinese_chars": true,
53
+ "tokenizer_class": "BertTokenizer",
54
+ "unk_token": "[UNK]"
55
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:772843bd8850c3c86ff160e8c7e9457e2b3e5a7bf7bf27f2b2a6453662366a70
3
+ size 4984
vocab.txt ADDED
The diff for this file is too large to render. See raw diff