Delete train_script.py
train_script.py  DELETED  +0 -164
@@ -1,164 +0,0 @@
import argparse
import logging
import os
import sys
from datetime import datetime
from shutil import copyfile

import torch
import tqdm
import wandb
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser()
parser.add_argument("--lang", required=True)
parser.add_argument("--model_name", default="google/mt5-base")
parser.add_argument("--epochs", default=4, type=int)
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--max_source_length", default=320, type=int)
parser.add_argument("--max_target_length", default=64, type=int)
parser.add_argument("--eval_size", default=1000, type=int)
#parser.add_argument("--fp16", default=False, action='store_true')
args = parser.parse_args()

wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")


def main():
    ############ Load dataset
    # Map query id -> query text for the selected language.
    queries = {}
    for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
        queries[row['id']] = row['text']

    # The collection split is indexed by position, so it can be used directly;
    # copying it into an id -> text dict is unnecessary.
    collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']

    train_pairs = []
    eval_pairs = []

    # qrels.train.tsv lists (query id, passage id) relevance judgments in TREC qrels format.
    with open('qrels.train.tsv') as fIn:
        for line in fIn:
            qid, _, did, _ = line.strip().split("\t")

            qid = int(qid)
            did = int(did)

            assert did == collection[did]['id']
            text = collection[did]['text']

            # (target query, input passage): the model learns passage -> query.
            pair = (queries[qid], text)
            if len(eval_pairs) < args.eval_size:
                eval_pairs.append(pair)
            else:
                train_pairs.append(pair)

    print(f"Train pairs: {len(train_pairs)}")

    ############ Model
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 1000

    output_dir = 'output/' + args.lang + '-' + args.model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    # Write self to path so every run keeps a copy of the exact training script.
    os.makedirs(output_dir, exist_ok=True)

    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        bf16=True,
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    # Sanity check: print the first (input passage, target query) pair of each split.
    print("Input:", train_pairs[0][1])
    print("Target:", train_pairs[0][0])

    print("Input:", eval_pairs[0][1])
    print("Target:", eval_pairs[0][0])

    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100

        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)

        # Replace every tokenizer.pad_token_id in the labels with -100 so that
        # padding positions are ignored by the loss.
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
        ]

        model_inputs["labels"] = torch.tensor(labels["input_ids"])
        return model_inputs

    ## Define the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_pairs,
        eval_dataset=eval_pairs,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    ### Train and save the model
    train_result = trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()

# Script was called via:
#python train_hf_trainer_multilingual.py --lang vietnamese