Upload 4 files
- T5Trainer.py +289 -0
- training_data10k.txt +0 -0
- training_data5K.txt +0 -0
- validation_data.txt +0 -0
T5Trainer.py
ADDED
@@ -0,0 +1,289 @@
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers import AdamWeightDecay
import tensorflow as tf
import random
from transformers import logging as hf_logging
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import numpy as np
import textwrap
import argparse
import re
import warnings
import os
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
hf_logging.set_verbosity_error()

np.random.seed(1234)
tf.random.set_seed(1234)
random.seed(1234)


def create_arg_parser():
    '''Creating command line arguments'''
    parser = argparse.ArgumentParser()

    parser.add_argument("-tf", "--transformer", default="google/byt5-small",
                        type=str, help="this argument takes the pretrained "
                                       "language model URL from HuggingFace, "
                                       "default is google/byt5-small, please "
                                       "visit HuggingFace for the full URL")
    parser.add_argument("-c_model", "--custom_model",
                        type=str, help="this argument takes a custom "
                                       "pretrained checkpoint")
    parser.add_argument("-train", "--train_data", default='training_data.txt',
                        type=str, help="this argument takes the train "
                                       "data file as input")
    parser.add_argument("-dev", "--dev_data", default='dev_data.txt', type=str,
                        help="this argument takes the dev data file as "
                             "input")
    parser.add_argument("-sample_weight", "--sample_weight", type=str,
                        help="class weights for custom loss calculation")
    parser.add_argument("-lr", "--learn_rate", default=1e-3, type=float,
                        help="Set a custom learn rate for "
                             "the pretrained language model, default is 1e-3")
    parser.add_argument("-bs", "--batch_size", default=16, type=int,
                        help="Set a custom batch size for "
                             "the pretrained language model, default is 16")
    parser.add_argument("-sl_train", "--sequence_length_train", default=100,
                        type=int, help="Set a custom maximum sequence length "
                                       "for the pretrained language model, "
                                       "default is 100")
    parser.add_argument("-sl_dev", "--sequence_length_dev", default=100,
                        type=int, help="Set a custom maximum sequence length "
                                       "for the pretrained language model, "
                                       "default is 100")
    parser.add_argument("-ep", "--epochs", default=1, type=int,
                        help="This argument selects the amount of epochs "
                             "to run the model with, default is 1 epoch")
    parser.add_argument("-es", "--early_stop", default="val_loss", type=str,
                        help="Set the value to monitor for earlystopping")
    parser.add_argument("-es_p", "--early_stop_patience", default=2,
                        type=int, help="Set the patience value for "
                                       "earlystopping")
    args = parser.parse_args()
    return args


def read_data(data_file):
    '''Reading in data files'''
    with open(data_file) as file:
        data = file.readlines()

    text = []
    for d in data:
        text.append(d)
    return text


def create_data(data):
    '''Splitting Alpino format training data into separate
    source and target sentences'''
    source_text = []
    target_text = []
    for x in data:
        source = []
        target = []
        spel = re.findall(r'\[.*?\]', x)
        if spel:
            for s in spel:
                s = s.split()
                if s[1] == '@alt':
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]))
                elif s[1] == '@mwu_alt':
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]).replace('-', ''))
                elif s[1] == '@mwu':
                    target.append(''.join(s[2:-1]))
                    source.append(' '.join(s[2:-1]))
                elif s[1] == '@postag':
                    target.append(''.join(s[-2]))
                    source.append(''.join(s[-2]))
                elif s[1] == '@phantom':
                    target.append(''.join(s[2]))
                    source.append('')

        target2 = []
        for t in target:
            if t[0] == '~':
                t = t.split('~')
                target2.append(t[1])
            else:
                target2.append(t)

        sent = re.sub(r'\[.*?\]', 'EMPTY', x)
        word_c = 0
        src = []
        trg = []
        for word in sent.split():
            if word == 'EMPTY':
                src.append(source[word_c])
                trg.append(target2[word_c])
                word_c += 1
            else:
                src.append(word)
                trg.append(word)
        source_text.append(' '.join(src))
        target_text.append(' '.join(trg))
    return source_text, target_text


def split_sent(data, max_length):
    '''Splitting sentences if longer than given threshold'''
    short_sent = []
    long_sent = []
    for n in data:
        n = n.split('|')
        if len(n[1]) <= max_length:
            short_sent.append(n[1])
        elif len(n[1]) > max_length:
            n[1] = re.sub(r'(\s)+(?=[^[]*?\])', '$$', n[1])
            n[1] = n[1].replace("] [", "]##[")
            lines = textwrap.wrap(n[1], max_length, break_long_words=False)
            long_sent.append(lines)

    new_data = []
    for s in long_sent:
        for s1 in s:
            s1 = s1.replace(']##[', '] [')
            s1 = s1.replace('$$', ' ')
            s2 = s1.split()
            if len(s2) > 2:
                new_data.append(s1)

    for x in short_sent:
        new_data.append(x)
    return new_data


def preprocess_function(tk, s, t):
    '''Tokenizing data text and labels'''
    model_inputs = tk(s)

    with tk.as_target_tokenizer():
        labels = tk(t)

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs


def convert_tok(tok, sl):
    '''Convert tokenized object to Tensors and add padding'''
    input_ids = []
    attention_mask = []
    labels = []
    decoder_attention_mask = []
    for a, b, c, d in zip(tok['input_ids'], tok['attention_mask'],
                          tok['labels'], tok['decoder_attention_mask']):
        input_ids.append(a)
        attention_mask.append(b)
        labels.append(c)
        decoder_attention_mask.append(d)

    input_ids_pad = pad_sequences(input_ids, padding='post', maxlen=sl)
    attention_mask_pad = pad_sequences(attention_mask, padding='post',
                                       maxlen=sl)
    labels_pad = pad_sequences(labels, padding='post', maxlen=sl)
    dec_attention_mask_pad = pad_sequences(decoder_attention_mask,
                                           padding='post', maxlen=sl)
    return {'input_ids': tf.constant(input_ids_pad),
            'attention_mask': tf.constant(attention_mask_pad),
            'labels': tf.constant(labels_pad),
            'decoder_attention_mask': tf.constant(dec_attention_mask_pad)}


def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
    '''Finetune and save a given T5 version with given (hyper)parameters'''
    print('Training model: {}\nWith parameters:\nLearn rate: {}, '
          'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
          'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))

    tk = AutoTokenizer.from_pretrained(model_name)

    args = create_arg_parser()
    source_train, target_train = create_data(train)
    source_test, target_test = create_data(dev)

    if args.custom_model:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(args.custom_model,
                                                        from_pt=True)
    else:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    train_tok = preprocess_function(tk, source_train, target_train)
    dev_tok = preprocess_function(tk, source_test, target_test)

    tf_train = convert_tok(train_tok, sl_train)
    tf_dev = convert_tok(dev_tok, sl_dev)

    optim = AdamWeightDecay(learning_rate=lr)
    model.compile(optimizer=optim, loss=custom_loss,
                  metrics=[accuracy])
    ear_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p,
                                                restore_best_weights=True,
                                                mode="auto")
    model.fit(tf_train, validation_data=tf_dev, epochs=ep,
              batch_size=bs, callbacks=[ear_stop])
    model.save_weights('{}_weights.h5'.format(model_name[7:]))
    return model


def custom_loss(y_true, y_pred):
    '''Custom loss function: sparse categorical cross-entropy
    with padding tokens (id 0) masked out'''
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_fn(y_true, y_pred)

    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)


def accuracy(y_true, y_pred):
    '''Custom accuracy function with padding tokens (id 0) masked out'''
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)

    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)


def main():
    args = create_arg_parser()

    lr = args.learn_rate
    bs = args.batch_size
    sl_train = args.sequence_length_train
    sl_dev = args.sequence_length_dev
    split_length_train = (sl_train - 5)
    split_length_dev = (sl_dev - 5)
    ep = args.epochs

    if args.transformer == 'google/flan-t5-small':
        model_name = 'google/flan-t5-small'
    elif args.transformer == 'google/byt5-small':
        model_name = 'google/byt5-small'
    elif args.transformer == 'google/mt5-small':
        model_name = 'google/mt5-small'
    else:
        model_name = 'Unknown'

    early_stop = args.early_stop
    patience = args.early_stop_patience

    train_d = read_data(args.train_data)
    dev_d = read_data(args.dev_data)
    train_data = split_sent(train_d, split_length_train)
    dev_data = split_sent(dev_d, split_length_dev)

    print('Train size: {}\nDev size: {}\n'.format(len(train_data),
                                                  len(dev_data)))
    print(train_model(model_name, lr, bs, sl_train, sl_dev,
                      ep, early_stop, patience, train_data, dev_data))


if __name__ == '__main__':
    main()
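For reference, a minimal sketch of how the script might be invoked, using the flags defined in create_arg_parser and the data files uploaded in this commit; the hyperparameter values shown are simply the argparse defaults, not tuned settings:

python T5Trainer.py --transformer google/byt5-small --train_data training_data5K.txt --dev_data validation_data.txt --learn_rate 1e-3 --batch_size 16 --sequence_length_train 100 --sequence_length_dev 100 --epochs 1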
training_data10k.txt
ADDED
The diff for this file is too large to render.
training_data5K.txt
ADDED
The diff for this file is too large to render.
validation_data.txt
ADDED
The diff for this file is too large to render.