AWolters committed on
Commit
4403157
1 Parent(s): 67c1245

Upload 4 files

T5Trainer.py ADDED
@@ -0,0 +1,289 @@
+ from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
+ from transformers import AdamWeightDecay
+ import tensorflow as tf
+ import random
+ from transformers import logging as hf_logging
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ from sklearn.model_selection import train_test_split
+ import numpy as np
+ import textwrap
+ import argparse
+ import re
+ import warnings
+ import os
+ warnings.filterwarnings("ignore")
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+ hf_logging.set_verbosity_error()
+
+ np.random.seed(1234)
+ tf.random.set_seed(1234)
+ random.seed(1234)
+
+
+ def create_arg_parser():
+     '''Creating command line arguments'''
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("-tf", "--transformer", default="google/byt5-small",
+                         type=str, help="this argument takes the pretrained "
+                         "language model URL from HuggingFace, "
+                         "default is google/byt5-small, please visit "
+                         "HuggingFace for the full URL")
+     parser.add_argument("-c_model", "--custom_model",
+                         type=str, help="this argument takes a custom "
+                         "pretrained checkpoint")
+     parser.add_argument("-train", "--train_data", default='training_data.txt',
+                         type=str, help="this argument takes the train "
+                         "data file as input")
+     parser.add_argument("-dev", "--dev_data", default='dev_data.txt', type=str,
+                         help="this argument takes the dev data file as "
+                         "input")
+     parser.add_argument("-sample_weight", "--sample_weight", type=str,
+                         help="class weights for custom loss calculation")
+     parser.add_argument("-lr", "--learn_rate", default=1e-3, type=float,
+                         help="Set a custom learn rate for "
+                         "the pretrained language model, default is 1e-3")
+     parser.add_argument("-bs", "--batch_size", default=16, type=int,
+                         help="Set a custom batch size for "
+                         "the pretrained language model, default is 16")
+     parser.add_argument("-sl_train", "--sequence_length_train", default=100,
+                         type=int, help="Set a custom maximum sequence length "
+                         "for the pretrained language model, "
+                         "default is 100")
+     parser.add_argument("-sl_dev", "--sequence_length_dev", default=100,
+                         type=int, help="Set a custom maximum sequence length "
+                         "for the pretrained language model, "
+                         "default is 100")
+     parser.add_argument("-ep", "--epochs", default=1, type=int,
+                         help="This argument selects the number of epochs "
+                         "to run the model with, default is 1 epoch")
+     parser.add_argument("-es", "--early_stop", default="val_loss", type=str,
+                         help="Set the value to monitor for early stopping")
+     parser.add_argument("-es_p", "--early_stop_patience", default=2,
+                         type=int, help="Set the patience value for "
+                         "early stopping")
+     args = parser.parse_args()
+     return args
+
+
+ def read_data(data_file):
+     '''Reading in data files'''
+     with open(data_file) as file:
+         data = file.readlines()
+
+     text = []
+     for d in data:
+         text.append(d)
+     return text
+
+
+ def create_data(data):
+     '''Splitting Alpino format training data into separate
+     source and target sentences'''
+     source_text = []
+     target_text = []
+     for x in data:
+         source = []
+         target = []
+         spel = re.findall(r'\[.*?\]', x)
+         if spel:
+             for s in spel:
+                 s = s.split()
+                 if s[1] == '@alt':
+                     target.append(''.join(s[2:3]))
+                     source.append(''.join(s[3:-1]))
+                 elif s[1] == '@mwu_alt':
+                     target.append(''.join(s[2:3]))
+                     source.append(''.join(s[3:-1]).replace('-', ''))
+                 elif s[1] == '@mwu':
+                     target.append(''.join(s[2:-1]))
+                     source.append(' '.join(s[2:-1]))
+                 elif s[1] == '@postag':
+                     target.append(''.join(s[-2]))
+                     source.append(''.join(s[-2]))
+                 elif s[1] == '@phantom':
+                     target.append(''.join(s[2]))
+                     source.append('')
+
+         target2 = []
+         for t in target:
+             if t[0] == '~':
+                 t = t.split('~')
+                 target2.append(t[1])
+             else:
+                 target2.append(t)
+
+         sent = re.sub(r'\[.*?\]', 'EMPTY', x)
+         word_c = 0
+         src = []
+         trg = []
+         for word in sent.split():
+             if word == 'EMPTY':
+                 src.append(source[word_c])
+                 trg.append(target2[word_c])
+                 word_c += 1
+             else:
+                 src.append(word)
+                 trg.append(word)
+         source_text.append(' '.join(src))
+         target_text.append(' '.join(trg))
+     return source_text, target_text
+
+
+ def split_sent(data, max_length):
+     '''Splitting sentences if longer than given threshold'''
+     short_sent = []
+     long_sent = []
+     for n in data:
+         n = n.split('|')
+         if len(n[1]) <= max_length:
+             short_sent.append(n[1])
+         elif len(n[1]) > max_length:
+             n[1] = re.sub(r'(\s)+(?=[^[]*?\])', '$$', n[1])
+             n[1] = n[1].replace("] [", "]##[")
+             lines = textwrap.wrap(n[1], max_length, break_long_words=False)
+             long_sent.append(lines)
+
+     new_data = []
+     for s in long_sent:
+         for s1 in s:
+             s1 = s1.replace(']##[', '] [')
+             s1 = s1.replace('$$', ' ')
+             s2 = s1.split()
+             if len(s2) > 2:
+                 new_data.append(s1)
+
+     for x in short_sent:
+         new_data.append(x)
+     return new_data
+
+
+ def preprocess_function(tk, s, t):
+     '''Tokenizing data text and labels'''
+     model_inputs = tk(s)
+
+     with tk.as_target_tokenizer():
+         labels = tk(t)
+
+     model_inputs["labels"] = labels["input_ids"]
+     model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+     return model_inputs
+
+
+ def convert_tok(tok, sl):
+     '''Convert tokenized object to Tensors and add padding'''
+     input_ids = []
+     attention_mask = []
+     labels = []
+     decoder_attention_mask = []
+     for a, b, c, d in zip(tok['input_ids'], tok['attention_mask'], tok['labels'],
+                           tok['decoder_attention_mask']):
+         input_ids.append(a)
+         attention_mask.append(b)
+         labels.append(c)
+         decoder_attention_mask.append(d)
+
+     input_ids_pad = pad_sequences(input_ids, padding='post', maxlen=sl)
+     attention_mask_pad = pad_sequences(attention_mask, padding='post',
+                                        maxlen=sl)
+     labels_pad = pad_sequences(labels, padding='post', maxlen=sl)
+     dec_attention_mask_pad = pad_sequences(decoder_attention_mask,
+                                            padding='post', maxlen=sl)
+     return {'input_ids': tf.constant(input_ids_pad), 'attention_mask':
+             tf.constant(attention_mask_pad), 'labels': tf.constant(labels_pad),
+             'decoder_attention_mask': tf.constant(dec_attention_mask_pad)}
+
+
+ def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
+     '''Fine-tune and save a given T5 version with given (hyper)parameters'''
+     print('Training model: {}\nWith parameters:\nLearn rate: {}, '
+           'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
+           'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
+
+     tk = AutoTokenizer.from_pretrained(model_name)
+
+     args = create_arg_parser()
+     source_train, target_train = create_data(train)
+     source_test, target_test = create_data(dev)
+
+     if args.custom_model:
+         model = TFAutoModelForSeq2SeqLM.from_pretrained(args.custom_model,
+                                                         from_pt=True)
+     else:
+         model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+     train_tok = preprocess_function(tk, source_train, target_train)
+     dev_tok = preprocess_function(tk, source_test, target_test)
+
+     tf_train = convert_tok(train_tok, sl_train)
+     tf_dev = convert_tok(dev_tok, sl_dev)
+
+     optim = AdamWeightDecay(learning_rate=lr)
+     model.compile(optimizer=optim, loss=custom_loss,
+                   metrics=[accuracy])
+     ear_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p,
+                                                 restore_best_weights=True,
+                                                 mode="auto")
+     model.fit(tf_train, validation_data=tf_dev, epochs=ep,
+               batch_size=bs, callbacks=[ear_stop])
+     model.save_weights('{}_weights.h5'.format(model_name[7:]))
+     return model
+
+
+ def custom_loss(y_true, y_pred):
+     '''Custom loss function that ignores padded positions (label id 0)'''
+     loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+         from_logits=True, reduction='none')
+     loss = loss_fn(y_true, y_pred)
+
+     mask = tf.cast(y_true != 0, loss.dtype)
+     loss *= mask
+     return tf.reduce_sum(loss)/tf.reduce_sum(mask)
+
+
+ def accuracy(y_true, y_pred):
+     '''Custom accuracy function that ignores padded positions'''
+     y_pred = tf.argmax(y_pred, axis=-1)
+     y_pred = tf.cast(y_pred, y_true.dtype)
+
+     match = tf.cast(y_true == y_pred, tf.float32)
+     mask = tf.cast(y_true != 0, tf.float32)
+     # Count matches only on non-padded positions so the score stays in [0, 1]
+     match *= mask
+     return tf.reduce_sum(match)/tf.reduce_sum(mask)
+
+
+ def main():
+     args = create_arg_parser()
+
+     lr = args.learn_rate
+     bs = args.batch_size
+     sl_train = args.sequence_length_train
+     sl_dev = args.sequence_length_dev
+     split_length_train = (sl_train - 5)
+     split_length_dev = (sl_dev - 5)
+     ep = args.epochs
+
+     if args.transformer == 'google/flan-t5-small':
+         model_name = 'google/flan-t5-small'
+     elif args.transformer == 'google/byt5-small':
+         model_name = 'google/byt5-small'
+     elif args.transformer == 'google/mt5-small':
+         model_name = 'google/mt5-small'
+     else:
+         model_name = 'Unknown'
+
+     early_stop = args.early_stop
+     patience = args.early_stop_patience
+
+     train_d = read_data(args.train_data)
+     dev_d = read_data(args.dev_data)
+     train_data = split_sent(train_d, split_length_train)
+     dev_data = split_sent(dev_d, split_length_dev)
+
+     print('Train size: {}\nDev size: {}\n'.format(len(train_data),
+                                                   len(dev_data)))
+     print(train_model(model_name, lr, bs, sl_train, sl_dev,
+                       ep, early_stop, patience, train_data, dev_data))
+
+
+ if __name__ == '__main__':
+     main()
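
A minimal sketch of how create_data turns one Alpino-style annotated line into a source/target pair. The bracketed @alt example line below is invented to match what the parsing code expects ("[ @alt <corrected> <original> ]"); it is not taken from the actual training files, and it assumes T5Trainer.py is importable from the working directory.

    # Hypothetical annotated line: the @alt block marks a spelling correction,
    # all other words are copied unchanged to both sides.
    from T5Trainer import create_data

    line = 'dit is een [ @alt voorbeeld vooorbeeld ] zin'

    source, target = create_data([line])
    print(source)  # ['dit is een vooorbeeld zin']  (original, misspelled text)
    print(target)  # ['dit is een voorbeeld zin']   (corrected text)
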
training_data10k.txt ADDED
The diff for this file is too large to render. See raw diff
 
training_data5K.txt ADDED
The diff for this file is too large to render. See raw diff
 
validation_data.txt ADDED
The diff for this file is too large to render. See raw diff
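
As a quick sanity check for the padding-masked loss and accuracy defined in T5Trainer.py, a small sketch; the toy shapes and label values are invented, and label id 0 is treated as padding exactly as in the code. It assumes T5Trainer.py is importable from the working directory.

    # Toy batch: one sequence of length 4 with two real tokens and two pads.
    import tensorflow as tf
    from T5Trainer import custom_loss, accuracy

    vocab_size = 10
    y_true = tf.constant([[5, 3, 0, 0]])             # two real tokens, two pads (id 0)
    y_pred = tf.random.uniform((1, 4, vocab_size))   # toy logits, shape (batch, seq, vocab)

    print(custom_loss(y_true, y_pred).numpy())  # cross-entropy averaged over the 2 real tokens only
    print(accuracy(y_true, y_pred).numpy())     # fraction of the 2 real tokens predicted correctly
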