tikim committed
Commit 645fa57
1 Parent(s): dea46fb

Add train and test code

Files changed (3)
  1. test.py +46 -0
  2. test_eval.ipynb +183 -0
  3. training.ipynb +261 -0
test.py ADDED
@@ -0,0 +1,46 @@
+from transformers import (
+    EncoderDecoderModel,
+    PreTrainedTokenizerFast,
+    # XLMRobertaTokenizerFast,
+    BertJapaneseTokenizer,
+    BertTokenizerFast,
+)
+
+import pandas as pd
+csv_test = pd.read_csv('./output/ffac_full.csv')
+# csv_test = pd.read_csv('ffac_test.csv')
+
+import csv
+
+encoder_model_name = "cl-tohoku/bert-base-japanese-v2"
+decoder_model_name = "skt/kogpt2-base-v2"
+
+src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)
+trg_tokenizer = PreTrainedTokenizerFast.from_pretrained(decoder_model_name)
+model = EncoderDecoderModel.from_pretrained("./dump/best_model")
+
+def main():
+    data_test = []
+    data_test_label = []
+    data_test_infer = []
+    for row in csv_test.itertuples():
+        data_test.append(row[1])
+        data_test_label.append(row[2])
+
+    for text in data_test:
+        embeddings = src_tokenizer(text, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')
+        embeddings = {k: v for k, v in embeddings.items()}
+        output = model.generate(**embeddings)[0, 1:-1]
+        result = trg_tokenizer.decode(output.cpu())
+        # print(result)
+        data_test_infer.append(result)
+
+    rows = zip(data_test, data_test_infer, data_test_label)
+    with open('test_result.csv', 'w') as f:
+        writer = csv.writer(f)
+        writer.writerow(['text', 'inference', 'answer'])
+        for row in rows:
+            writer.writerow(row)
+
+if __name__ == "__main__":
+    main()
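
For reference, test.py above writes its predictions to test_result.csv with the columns text, inference, and answer. A minimal sketch of inspecting that file with pandas, assuming the script has already been run from the repository root (the exact-match check at the end is only a rough sanity metric, not part of the original code):

import pandas as pd

# test_result.csv is produced by test.py with columns: text, inference, answer.
results = pd.read_csv('test_result.csv')

# Compare a few source sentences against the model output and the reference.
print(results.head())

# Rough sanity metric: fraction of predictions that match the reference exactly.
print((results['inference'] == results['answer']).mean())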
test_eval.ipynb ADDED
@@ -0,0 +1,183 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import (\n",
+    "    EncoderDecoderModel,\n",
+    "    PreTrainedTokenizerFast,\n",
+    "    # XLMRobertaTokenizerFast,\n",
+    "    BertJapaneseTokenizer,\n",
+    "    BertTokenizerFast,\n",
+    ")\n",
+    "\n",
+    "import torch\n",
+    "import csv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
+      "The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. \n",
+      "The class this function is called from is 'PreTrainedTokenizerFast'.\n"
+     ]
+    }
+   ],
+   "source": [
+    "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
+    "decoder_model_name = \"skt/kogpt2-base-v2\"\n",
+    "\n",
+    "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
+    "trg_tokenizer = PreTrainedTokenizerFast.from_pretrained(decoder_model_name)\n",
+    "model = EncoderDecoderModel.from_pretrained(\"./dump/best_model\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'길가메시 토벌전'"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "text = \"ギルガメッシュ討伐戦\"\n",
+    "# text = \"ギルガメッシュ討伐戦に行ってきます。一緒に行きましょうか?\"\n",
+    "\n",
+    "def translate(text_src):\n",
+    "    embeddings = src_tokenizer(text_src, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')\n",
+    "    embeddings = {k: v for k, v in embeddings.items()}\n",
+    "    output = model.generate(**embeddings)[0, 1:-1]\n",
+    "    text_trg = trg_tokenizer.decode(output.cpu())\n",
+    "    return text_trg\n",
+    "\n",
+    "print(translate(text))"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Evaluation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
+    "smoothie = SmoothingFunction().method4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Testing: 0%| | 0/267 [00:00<?, ?it/s]/home/tikim/.local/lib/python3.8/site-packages/transformers/generation/utils.py:1288: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
+      "  warnings.warn(\n",
+      "Testing: 100%|██████████| 267/267 [01:01<00:00, 4.34it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Bleu score: 0.9619225967540574\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from tqdm import tqdm\n",
+    "from statistics import mean\n",
+    "\n",
+    "bleu = []\n",
+    "f1 = []\n",
+    "\n",
+    "DATA_ROOT = './output'\n",
+    "FILE_JP_KO_TEST = 'ja_ko_test.csv'\n",
+    "FILE_FFAC_TEST = 'ffac_test.csv'\n",
+    "\n",
+    "with torch.no_grad(), open(f'{DATA_ROOT}/{FILE_FFAC_TEST}', 'r') as fd:\n",
+    "# with torch.no_grad(), open(f'{DATA_ROOT}/{FILE_JP_KO_TEST}', 'r') as fd:\n",
+    "    reader = csv.reader(fd)\n",
+    "    next(reader)\n",
+    "    datas = [row for row in reader]\n",
+    "\n",
+    "    for data in tqdm(datas, \"Testing\"):\n",
+    "        input, label = data\n",
+    "        embeddings = src_tokenizer(input, return_attention_mask=False, return_token_type_ids=False, return_tensors='pt')\n",
+    "        embeddings = {k: v for k, v in embeddings.items()}\n",
+    "        with torch.no_grad():\n",
+    "            output = model.generate(**embeddings)[0, 1:-1]\n",
+    "        preds = trg_tokenizer.decode(output.cpu())\n",
+    "\n",
+    "        bleu.append(sentence_bleu([label.split()], preds.split(), weights=[1,0,0,0], smoothing_function=smoothie))\n",
+    "\n",
+    "print(f\"Bleu score: {mean(bleu)}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
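
The evaluation cell above scores each prediction with unigram BLEU (weights=[1, 0, 0, 0]) and method4 smoothing, then averages the scores over the 267 test rows. A small self-contained sketch of that scoring on a single pair; the Korean strings below are made-up placeholders, not data from the test set:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothie = SmoothingFunction().method4

# Hypothetical reference/prediction pair, whitespace-tokenized as in the notebook.
reference = "길가메시 토벌전 공략 안내".split()
prediction = "길가메시 토벌전".split()

# weights=[1, 0, 0, 0] keeps only unigram precision (1.0 here); the short
# prediction is still penalized by the brevity penalty, exp(1 - 4/2) ≈ 0.37.
score = sentence_bleu([reference], prediction, weights=[1, 0, 0, 0], smoothing_function=smoothie)
print(score)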
training.ipynb ADDED
@@ -0,0 +1,261 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The code below is primarily based on [akpe12/JP-KR-ocr-translator-for-travel](https://github.com/akpe12/JP-KR-ocr-translator-for-travel)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "TrHlPFqwFAgj"
+   },
+   "source": [
+    "## Import"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "t-jXeSJKE1WM"
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "from typing import Dict, List\n",
+    "import csv\n",
+    "import torch\n",
+    "from transformers import (\n",
+    "    EncoderDecoderModel,\n",
+    "    GPT2Tokenizer as BaseGPT2Tokenizer,\n",
+    "    PreTrainedTokenizer, BertTokenizerFast,\n",
+    "    PreTrainedTokenizerFast,\n",
+    "    DataCollatorForSeq2Seq,\n",
+    "    Seq2SeqTrainingArguments,\n",
+    "    AutoTokenizer,\n",
+    "    XLMRobertaTokenizerFast,\n",
+    "    BertJapaneseTokenizer,\n",
+    "    Trainer\n",
+    ")\n",
+    "from torch.utils.data import DataLoader\n",
+    "from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel\n",
+    "\n",
+    "# encoder_model_name = \"xlm-roberta-base\"\n",
+    "encoder_model_name = \"cl-tohoku/bert-base-japanese-v2\"\n",
+    "decoder_model_name = \"skt/kogpt2-base-v2\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "nEW5trBtbykK"
+   },
+   "outputs": [],
+   "source": [
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "# device = torch.device(\"cpu\")\n",
+    "device, torch.cuda.device_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "5ic7pUUBFU_v"
+   },
+   "outputs": [],
+   "source": [
+    "class GPT2Tokenizer(PreTrainedTokenizerFast):\n",
+    "    def build_inputs_with_special_tokens(self, token_ids: List[int]) -> List[int]:\n",
+    "        return token_ids + [self.eos_token_id]\n",
+    "\n",
+    "src_tokenizer = BertJapaneseTokenizer.from_pretrained(encoder_model_name)\n",
+    "trg_tokenizer = GPT2Tokenizer.from_pretrained(decoder_model_name, bos_token='</s>', eos_token='</s>', unk_token='<unk>',\n",
+    "                                              pad_token='<pad>', mask_token='<mask>')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "DTf4U1fmFQFh"
+   },
+   "source": [
+    "## Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "65L4O1c5FLKt"
+   },
+   "outputs": [],
+   "source": [
+    "class PairedDataset:\n",
+    "    def __init__(self,\n",
+    "                 src_tokenizer: PreTrainedTokenizerFast, tgt_tokenizer: PreTrainedTokenizerFast,\n",
+    "                 file_path: str\n",
+    "                 ):\n",
+    "        self.src_tokenizer = src_tokenizer\n",
+    "        self.trg_tokenizer = tgt_tokenizer\n",
+    "        with open(file_path, 'r') as fd:\n",
+    "            reader = csv.reader(fd)\n",
+    "            next(reader)\n",
+    "            self.data = [row for row in reader]\n",
+    "\n",
+    "    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:\n",
+    "        src, trg = self.data[index]\n",
+    "        embeddings = self.src_tokenizer(src, return_attention_mask=False, return_token_type_ids=False)\n",
+    "        embeddings['labels'] = self.trg_tokenizer.build_inputs_with_special_tokens(self.trg_tokenizer(trg, return_attention_mask=False)['input_ids'])\n",
+    "\n",
+    "        return embeddings\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.data)\n",
+    "\n",
+    "DATA_ROOT = './output'\n",
+    "FILE_FFAC_FULL = 'ffac_full.csv'\n",
+    "FILE_FFAC_TEST = 'ffac_test.csv'\n",
+    "# FILE_JA_KO_TRAIN = 'ja_ko_train.csv'\n",
+    "# FILE_JA_KO_TEST = 'ja_ko_test.csv'\n",
+    "\n",
+    "train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_FULL}')\n",
+    "eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_FFAC_TEST}')\n",
+    "# train_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TRAIN}')\n",
+    "# eval_dataset = PairedDataset(src_tokenizer, trg_tokenizer, f'{DATA_ROOT}/{FILE_JA_KO_TEST}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "uCBiLouSFiZY"
+   },
+   "source": [
+    "## Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "I7uFbFYJFje8"
+   },
+   "outputs": [],
+   "source": [
+    "model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n",
+    "    encoder_model_name,\n",
+    "    decoder_model_name,\n",
+    "    pad_token_id=trg_tokenizer.bos_token_id,\n",
+    ")\n",
+    "model.config.decoder_start_token_id = trg_tokenizer.bos_token_id"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "YFq2GyOAUV0W"
+   },
+   "outputs": [],
+   "source": [
+    "# for Trainer\n",
+    "import wandb\n",
+    "\n",
+    "collate_fn = DataCollatorForSeq2Seq(src_tokenizer, model)\n",
+    "wandb.init(project=\"fftr-poc1\", name='jbert+kogpt2')\n",
+    "\n",
+    "arguments = Seq2SeqTrainingArguments(\n",
+    "    output_dir='dump',\n",
+    "    do_train=True,\n",
+    "    do_eval=True,\n",
+    "    evaluation_strategy=\"epoch\",\n",
+    "    save_strategy=\"epoch\",\n",
+    "    # num_train_epochs=5,\n",
+    "    num_train_epochs=25,\n",
+    "    # per_device_train_batch_size=32,\n",
+    "    per_device_train_batch_size=64,\n",
+    "    # per_device_eval_batch_size=32,\n",
+    "    per_device_eval_batch_size=64,\n",
+    "    warmup_ratio=0.1,\n",
+    "    gradient_accumulation_steps=4,\n",
+    "    save_total_limit=5,\n",
+    "    dataloader_num_workers=1,\n",
+    "    fp16=True,\n",
+    "    load_best_model_at_end=True,\n",
+    "    report_to='wandb'\n",
+    ")\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model,\n",
+    "    arguments,\n",
+    "    data_collator=collate_fn,\n",
+    "    train_dataset=train_dataset,\n",
+    "    eval_dataset=eval_dataset\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "pPsjDHO5Vc3y"
+   },
+   "source": [
+    "## Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "_T4P4XunmK-C"
+   },
+   "outputs": [],
+   "source": [
+    "# model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"xlm-roberta-base\", \"skt/kogpt2-base-v2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7vTqAgW6Ve3J"
+   },
+   "outputs": [],
+   "source": [
+    "trainer.train()\n",
+    "\n",
+    "model.save_pretrained(\"dump/best_model\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "machine_shape": "hm",
+   "provenance": []
+  },
+  "gpuClass": "premium",
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
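
As a quick sanity check of the data pipeline defined in training.ipynb, one could decode a single PairedDataset item back to text. This is only a sketch and assumes it runs in the same notebook session, after the Data cell has created src_tokenizer, trg_tokenizer, and train_dataset:

# Each item holds BERT input_ids for the Japanese source and GPT-2 ids
# (with the eos token appended by build_inputs_with_special_tokens) as labels.
sample = train_dataset[0]
print(sample.keys())                               # input_ids, labels
print(src_tokenizer.decode(sample['input_ids']))   # Japanese source sentence
print(trg_tokenizer.decode(sample['labels']))      # Korean target sentence + </s>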