Text2Text Generation
Transformers
PyTorch
Vietnamese
mt5
Inference Endpoints
nluai committed
Commit
1a8640b
1 Parent(s): 20d64f4

Delete train_script.py

Files changed (1)
  1. train_script.py +0 -164
train_script.py DELETED
@@ -1,164 +0,0 @@
- import argparse
- import logging
- from torch.utils.data import Dataset, IterableDataset
- import gzip
- import json
- from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
- import sys
- from datetime import datetime
- import torch
- import random
- from shutil import copyfile
- import os
- import wandb
- import random
- import re
- from datasets import load_dataset
- import tqdm
-
-
- logging.basicConfig(
-     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-     datefmt="%Y-%m-%d %H:%M:%S",
-     handlers=[logging.StreamHandler(sys.stdout)],
- )
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--lang", required=True)
- parser.add_argument("--model_name", default="google/mt5-base")
- parser.add_argument("--epochs", default=4, type=int)
- parser.add_argument("--batch_size", default=32, type=int)
- parser.add_argument("--max_source_length", default=320, type=int)
- parser.add_argument("--max_target_length", default=64, type=int)
- parser.add_argument("--eval_size", default=1000, type=int)
- #parser.add_argument("--fp16", default=False, action='store_true')
- args = parser.parse_args()
-
- wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")
-
-
-
-
-
- def main():
-     ############ Load dataset
-     queries = {}
-     for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
-         queries[row['id']] = row['text']
-
-     """
-     collection = {}
-     for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']):
-         collection[row['id']] = row['text']
-     """
-     collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']
-
-     train_pairs = []
-     eval_pairs = []
-
-
-     with open('qrels.train.tsv') as fIn:
-         for line in fIn:
-             qid, _, did, _ = line.strip().split("\t")
-
-             qid = int(qid)
-             did = int(did)
-
-             assert did == collection[did]['id']
-             text = collection[did]['text']
-
-             pair = (queries[qid], text)
-             if len(eval_pairs) < args.eval_size:
-                 eval_pairs.append(pair)
-             else:
-                 train_pairs.append(pair)
-
-
-     print(f"Train pairs: {len(train_pairs)}")
-
-
-     ############ Model
-     model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
-     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-
-     save_steps = 1000
-
-     output_dir = 'output/'+args.lang+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-     print("Output dir:", output_dir)
-
-     # Write self to path
-     os.makedirs(output_dir, exist_ok=True)
-
-     train_script_path = os.path.join(output_dir, 'train_script.py')
-     copyfile(__file__, train_script_path)
-     with open(train_script_path, 'a') as fOut:
-         fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
-
-     ####
-
-     training_args = Seq2SeqTrainingArguments(
-         output_dir=output_dir,
-         bf16=True,
-         per_device_train_batch_size=args.batch_size,
-         evaluation_strategy="steps",
-         save_steps=save_steps,
-         logging_steps=100,
-         eval_steps=save_steps, #logging_steps,
-         warmup_steps=1000,
-         save_total_limit=1,
-         num_train_epochs=args.epochs,
-         report_to="wandb",
-     )
-
-     ############ Arguments
-
-     ############ Load datasets
-
-
-     print("Input:", train_pairs[0][1])
-     print("Target:", train_pairs[0][0])
-
-     print("Input:", eval_pairs[0][1])
-     print("Target:", eval_pairs[0][0])
-
-
-     def data_collator(examples):
-         targets = [row[0] for row in examples]
-         inputs = [row[1] for row in examples]
-         label_pad_token_id = -100
-
-         model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
-
-         # Setup the tokenizer for targets
-         with tokenizer.as_target_tokenizer():
-             labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
-
-         # replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
-         labels["input_ids"] = [
-             [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
-         ]
-
-
-         model_inputs["labels"] = torch.tensor(labels["input_ids"])
-         return model_inputs
-
-     ## Define the trainer
-     trainer = Seq2SeqTrainer(
-         model=model,
-         args=training_args,
-         train_dataset=train_pairs,
-         eval_dataset=eval_pairs,
-         tokenizer=tokenizer,
-         data_collator=data_collator
-     )
-
-     ### Save the model
-     train_result = trainer.train()
-     trainer.save_model()
-
-
- if __name__ == "__main__":
-     main()
-
- # Script was called via:
- #python train_hf_trainer_multilingual.py --lang vietnamese