liwii committed
Commit 2e36ddd
1 Parent(s): 4526c7e

Training in progress, epoch 1

added_tokens.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "<pad>": 0,
+   "<unk>": 1,
+   "[CLS]": 2,
+   "[MASK]": 4,
+   "[SEP]": 3
+ }
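These five IDs pin the special tokens to the bottom of the vocabulary. A quick sanity check is to load the tokenizer and compare; a minimal sketch, assuming the transformers package is installed (this repo's tokenizer ships custom code, so trust_remote_code=True may be required):

from transformers import AutoTokenizer

# Load the tokenizer this repo was built from (see auto_map in tokenizer_config.json).
tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True)

# Every entry in added_tokens.json should round-trip to the recorded ID.
for token, expected_id in {"<pad>": 0, "<unk>": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}.items():
    assert tokenizer.convert_tokens_to_ids(token) == expected_id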
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "_name_or_path": "line-corporation/line-distilbert-base-japanese",
+   "activation": "gelu",
+   "architectures": [
+     "ConsistentSentenceClassifier"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "contradiction",
+     "1": "neutral",
+     "2": "entailment"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "contradiction": 0,
+     "entailment": 2,
+     "neutral": 1
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "output_hidden_states": true,
+   "pad_token_id": 0,
+   "problem_type": "single_label_classification",
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": true,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.34.0",
+   "vocab_size": 32768
+ }
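The config wires the three-way NLI label set into the classification head (id2label/label2id), caps inputs at 512 positions, and sets output_hidden_states to true, which is why the metric code in utils.py below indexes predictions[0]. A short sketch of reading these values back, assuming the file sits in the working directory:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(".")  # reads ./config.json
print(config.id2label)                 # {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
print(config.num_labels)               # 3, derived from id2label
print(config.max_position_embeddings)  # 512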
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:daf276874da332369d4beb0dde2f6db9d412102c69746a1cce536d0d3851b783
+ size 274758317
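The weights themselves (about 275 MB) live in Git LFS; the repo stores only this three-line pointer (spec version, SHA-256 oid, byte size). A sketch of checking a downloaded blob against such a pointer (parse_lfs_pointer and verify_blob are illustrative names, not part of this repo):

import hashlib

def parse_lfs_pointer(pointer_path):
    # Each pointer line is "key value": version, oid, size.
    fields = dict(line.rstrip("\n").split(" ", 1) for line in open(pointer_path))
    algo, digest = fields["oid"].split(":")
    assert algo == "sha256"
    return digest, int(fields["size"])

def verify_blob(blob_path, pointer_path):
    digest, size = parse_lfs_pointer(pointer_path)
    data = open(blob_path, "rb").read()
    return len(data) == size and hashlib.sha256(data).hexdigest() == digest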
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "bos_token": "[CLS]",
+   "cls_token": "[CLS]",
+   "eos_token": "[SEP]",
+   "mask_token": "[MASK]",
+   "pad_token": "<pad>",
+   "sep_token": "[SEP]",
+   "unk_token": "<unk>"
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcfafc8c0662d9c8f39621a64c74260f2ad120310c8dd24886de2dddaf599b4e
+ size 439391
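spiece.model is the SentencePiece model behind the subword step (subword_tokenizer_type is "sentencepiece" in tokenizer_config.json below). It can be inspected directly; a minimal sketch assuming the sentencepiece package and the file fetched from LFS:

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spiece.model")
print(sp.get_piece_size())  # expected to line up with vocab_size (32768) in config.json
print(sp.encode("プロペラ機が何台も駐機しています。", out_type=str))  # subword pieces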
tokenizer_config.json ADDED
@@ -0,0 +1,76 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "auto_map": {
+     "AutoTokenizer": [
+       "line-corporation/line-distilbert-base-japanese--distilbert_japanese_tokenizer.DistilBertJapaneseTokenizer",
+       null
+     ]
+   },
+   "bos_token": "[CLS]",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "do_subword_tokenize": true,
+   "do_word_tokenize": true,
+   "eos_token": "[SEP]",
+   "jumanpp_kwargs": null,
+   "keep_accents": true,
+   "mask_token": "[MASK]",
+   "mecab_kwargs": {
+     "mecab_dic": "unidic_lite"
+   },
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "<pad>",
+   "remove_space": true,
+   "sep_token": "[SEP]",
+   "subword_tokenizer_type": "sentencepiece",
+   "sudachi_kwargs": null,
+   "tokenize_chinese_chars": false,
+   "tokenizer_class": "BertJapaneseTokenizer",
+   "tokenizer_file": null,
+   "unk_token": "<unk>",
+   "word_tokenizer_type": "mecab"
+ }
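Two details worth noting here. Word-level tokenization runs through MeCab with the unidic_lite dictionary, so the fugashi and unidic-lite packages must be installed alongside sentencepiece. And model_max_length is the "no limit" sentinel, which is exactly why the notebook below logs "Asking to pad to max_length but no maximum length is provided. Default to no padding." during dataset.map. A hedged sketch of pinning the length explicitly (512 matches max_position_embeddings in config.json):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
tokenizer.model_max_length = 512  # replace the huge sentinel so padding/truncation take effect

enc = tokenizer("川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。",
                padding="max_length", truncation=True)
print(len(enc["input_ids"]))  # 512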
train-v1.1.json ADDED
The diff for this file is too large to render. See raw diff
 
train_factual_consistency.ipynb ADDED
@@ -0,0 +1,240 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "b12ae8a3-9e08-402c-894c-31697fad6c56",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "8950b2cbd1c44912917219b84af806ce",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "from huggingface_hub import notebook_login\n",
+     "\n",
+     "notebook_login()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "160c80c1-0ca4-45df-8171-87cd3c88a223",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "\n",
+     "from transformers import (\n",
+     "    AutoTokenizer,\n",
+     "    DataCollatorWithPadding,\n",
+     "    Trainer,\n",
+     "    TrainingArguments,\n",
+     ")\n",
+     "from utils import ConsistentSentenceClassifier, get_metrics, load_dataset"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "25800588-5d42-4524-9dc6-a6a0c180b8b0",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       " text label\n",
+       "512 カーキ色の服を着た男性が、口元にリンゴを当てています。[SEP]カーキ色の服を着た男性が、口... 0\n",
+       "513 男性がグラウンドでボールを投げています。[SEP]白い髯を生やした男性がボールを投げています。 1\n",
+       "514 椅子に座った子供が、手づかみで食事をしています。[SEP]椅子に座った子供が手づかみで、食事... 2\n",
+       "515 プロペラ機が何台も駐機しています。[SEP]プロペラ機が何台も連なって飛んでいます。 0\n",
+       "516 消火栓から水が勢いよく噴き出しています。[SEP]水が噴き出している消火栓の水を浴びるように... 1\n",
+       "517 冷蔵庫のないキッチンにナイフとフォークが置かれています。[SEP]冷蔵庫の置かれたキッチンに... 0\n",
+       "518 うみでサーフィンをしているひとがいます。[SEP]黒いウェットスーツを着た人がサーフボードに... 1\n",
+       "519 池から白い鳥が飛び立っています。[SEP]森にある水の上を鳥が飛んでいます。 1\n",
+       "520 丈夫なビーチパラソルが立っています。[SEP]ビーチパラソルの支柱が折れ曲がっています。 0\n",
+       "521 白髪の男性が少女から花束を受け取っています。[SEP]花束を持った男性の前に多くの子供たちが... 1\n",
+       " text label\n",
+       "0 赤いひとつの傘に、二人の人が入っています。[SEP]歩道を歩く通行人が傘をさして歩いています。 1\n",
+       "1 川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。 0\n",
+       "2 ゲレンデのこぶでスキージャンプしています。[SEP]雪上でモーグルを楽しむ水色のウェアを着た女性。 1\n",
+       "3 黒いお皿に乗っているピザをカットしています。[SEP]黒い皿の上にピザが盛られています。 2\n",
+       "4 女性が目を細めて携帯電話で話をしています。[SEP]目を細めた女性が携帯電話で話をしています。 2\n",
+       "5 バナナやパパイヤなどの果物が売られている。[SEP]台の上にはバナナなどの青果が並べられています。 1\n",
+       "6 ヘッドライトを点灯させた白いバスが駐車場に止まっています。[SEP]ライトを点灯させているバ... 2\n",
+       "7 水面の上に、カイトサーフィンの凧が揚がっています。[SEP]海の上に水上スポーツ用の凧が揚が... 1\n",
+       "8 ホットドッグを野外で食べている人たちです。[SEP]家の中でホットドッグを食べている。 0\n",
+       "9 草が生い茂っている所に、3頭のゾウがいます。[SEP]草むらの中に三頭のゾウが立っているとこ... 1\n"
+      ]
+     },
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "0e80998f107e4e9a80886e682dc73c5a",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "Map: 0%| | 0/19561 [00:00<?, ? examples/s]"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.\n",
+       "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
+      ]
+     },
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "c8bb0daae6cf4c50bf2631a47324a8fb",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "Map: 0%| | 0/512 [00:00<?, ? examples/s]"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "tokenizer = AutoTokenizer.from_pretrained(\"line-corporation/line-distilbert-base-japanese\")\n",
+     "dataset = load_dataset('train-v1.1.json')\n",
+     "tokenized_dataset = dataset.map(\n",
+     "    lambda examples: tokenizer(examples[\"text\"], padding='max_length', truncation=True), batched=True\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "6bc83d4c-378c-4313-b641-8ead0c02f715",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "WARNING:root:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device or configure XRT. To disable default device selection, set PJRT_SELECT_DEFAULT_DEVICE=0\n",
+       "WARNING:root:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md\n",
+       "WARNING:root:Defaulting to PJRT_DEVICE=CPU\n"
+      ]
+     },
+     {
+      "data": {
+       "text/html": [
+        "\n",
+        "    <div>\n",
+        "      \n",
+        "      <progress value='198' max='9180' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+        "      [ 198/9180 01:00 < 46:32, 3.22 it/s, Epoch 0.64/30]\n",
+        "    </div>\n",
+        "    <table border=\"1\" class=\"dataframe\">\n",
+        "  <thead>\n",
+        " <tr style=\"text-align: left;\">\n",
+        "      <th>Epoch</th>\n",
+        "      <th>Training Loss</th>\n",
+        "      <th>Validation Loss</th>\n",
+        "    </tr>\n",
+        "  </thead>\n",
+        "  <tbody>\n",
+        "  </tbody>\n",
+        "</table><p>"
+       ],
+       "text/plain": [
+        "<IPython.core.display.HTML object>"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "model = ConsistentSentenceClassifier(\n",
+     "    freeze_bert=False)\n",
+     "\n",
+     "training_args = TrainingArguments(\n",
+     "    output_dir=\"../factual-consistency-classification-ja-avgpool-unfrozen\",\n",
+     "    learning_rate=1e-4,\n",
+     "    per_device_train_batch_size=64,\n",
+     "    per_device_eval_batch_size=8,\n",
+     "    num_train_epochs=30,\n",
+     "    weight_decay=0.02,\n",
+     "    evaluation_strategy=\"epoch\",\n",
+     "    eval_accumulation_steps=4,\n",
+     "    save_strategy=\"epoch\",\n",
+     "    load_best_model_at_end=True,\n",
+     "    save_total_limit=5,\n",
+     "    push_to_hub=True,\n",
+     ")\n",
+     "\n",
+     "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
+     "trainer = Trainer(\n",
+     "    model=model,\n",
+     "    args=training_args,\n",
+     "    train_dataset=tokenized_dataset[\"train\"],\n",
+     "    eval_dataset=tokenized_dataset[\"test\"],\n",
+     "    tokenizer=tokenizer,\n",
+     "    data_collator=data_collator,\n",
+     "    compute_metrics=get_metrics(),\n",
+     ")\n",
+     "\n",
+     "trainer.train()\n",
+     "trainer.push_to_hub('factual-consistency-classification-ja-avgpool-unfrozen')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a6eb93f7-5a38-49a2-be0d-e42267e23a0a",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "environment": {
+    "kernel": "python3",
+    "name": "pytorch-gpu.2-0.m112",
+    "type": "gcloud",
+    "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.2-0:m112"
+   },
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
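One subtlety in the last training cell: the first positional argument of Trainer.push_to_hub is the commit message, not the repository name; the target repo is derived from output_dir (or hub_model_id if set). A sketch of the more explicit form (the liwii/ namespace is assumed here, not stated in the notebook):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="../factual-consistency-classification-ja-avgpool-unfrozen",
    hub_model_id="liwii/factual-consistency-classification-ja-avgpool-unfrozen",  # assumed namespace
    push_to_hub=True,
    # ... remaining arguments as in the notebook ...
)
# Later: trainer.push_to_hub(commit_message="Training in progress, epoch 1")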
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16a8fea2d112223bb5fc50f0e3b8457dcd3eefa65312f57e80712e85717a5f1f
+ size 4155
utils.py ADDED
@@ -0,0 +1,123 @@
+ import json
+ import pandas as pd
+ import datasets
+ import numpy as np
+ import evaluate
+ import torch
+ from transformers import AutoModel, DistilBertForSequenceClassification
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from typing import Optional
+ 
+ SEP_TOKEN = '[SEP]'
+ LABEL2ID = {'entailment': 2, 'neutral': 1, 'contradiction': 0}
+ ID2LABEL = {2: 'entailment', 1: 'neutral', 0: 'contradiction'}
+ 
+ def format_dataset(arr):
+     text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
+     label = [LABEL2ID[el['label']] for el in arr]
+     new_df = pd.DataFrame({'text': text, 'label': label})
+     return new_df.sample(frac=1, random_state=42).reset_index(drop=True)  # deterministic shuffle
+ 
+ # Load the JSON Lines dataset and split it into train/test DatasetDicts
+ def load_dataset(path):
+     train_array = []
+     with open(path) as f:
+         for line in f.readlines():
+             if line:
+                 train_array.append(json.loads(line))
+     df = format_dataset(train_array)
+     # Split dataset into train and val
+     df_train = df.iloc[512:, :]
+     # We do not need much test data
+     df_test = df.iloc[:512, :]
+     print(df_train[:10])
+     print(df_test[:10])
+ 
+     factual_consistency_dataset = datasets.DatasetDict()
+     factual_consistency_dataset["train"] = datasets.Dataset.from_pandas(
+         df_train[["text", "label"]])
+     factual_consistency_dataset["test"] = datasets.Dataset.from_pandas(
+         df_test[["text", "label"]])
+ 
+     return factual_consistency_dataset
+ 
+ 
+ class ConsistentSentenceClassifier(DistilBertForSequenceClassification):
+ 
+     def __init__(self, freeze_bert=True):
+         base_model = AutoModel.from_pretrained(
+             'line-corporation/line-distilbert-base-japanese', num_labels=3)
+ 
+         config = base_model.config
+         super(ConsistentSentenceClassifier, self).__init__(config=config)
+         config.num_labels = 3
+         config.id2label = ID2LABEL
+         config.label2id = LABEL2ID
+         config.problem_type = "single_label_classification"
+ 
+         self.distilbert = base_model
+ 
+         if not freeze_bert:
+             return
+         # Freeze the DistilBERT encoder so only the classification head trains.
+         for param in self.distilbert.parameters():
+             param.requires_grad = False
+ 
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         distilbert_output = self.distilbert(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
+         pooled_output = torch.mean(hidden_state, dim=1)  # (bs, dim): average pooling instead of the [CLS] state
+         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
+         pooled_output = torch.nn.ReLU()(pooled_output)  # (bs, dim)
+         pooled_output = self.dropout(pooled_output)  # (bs, dim)
+         logits = self.classifier(pooled_output)  # (bs, num_labels)
+ 
+         loss = None
+         if labels is not None:
+             loss_fct = torch.nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ 
+         if not return_dict:
+             output = (logits,) + distilbert_output[1:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=distilbert_output.hidden_states,
+             attentions=distilbert_output.attentions,
+         )
+ 
+ 
+ # Set up evaluation metric
+ def get_metrics():
+     metric = evaluate.load("accuracy")
+ 
+     def compute_metrics(eval_pred):
+         predictions, labels = eval_pred
+         preds = predictions[0].argmax(axis=1)  # logits are the first element; hidden states follow
+         return metric.compute(predictions=preds, references=labels)
+ 
+     return compute_metrics
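For completeness, a hedged inference sketch using the pieces above: join the premise and hypothesis with [SEP] as load_dataset does, run the classifier, and map the argmax back through ID2LABEL. It assumes the trained pytorch_model.bin has been fetched from LFS into the working directory:

import torch
from transformers import AutoTokenizer

from utils import ConsistentSentenceClassifier, ID2LABEL, SEP_TOKEN

tokenizer = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
model = ConsistentSentenceClassifier(freeze_bert=True)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

text = "川を小さなボートが進んで行きます。" + SEP_TOKEN + "川を豪華客船が進んでいきます。"
enc = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(ID2LABEL[int(out.logits.argmax(dim=-1))])  # e.g. "contradiction"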