import argparse
import os
import random
import re
import textwrap
import warnings

# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native library is loaded, so
# it must be set before the first TensorFlow import (the transformers TF
# classes below import TensorFlow as well); setting it afterwards does not
# silence TensorFlow's C++ startup logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")

import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from sklearn.model_selection import train_test_split  # noqa: E402,F401
from tensorflow.keras.preprocessing.sequence import pad_sequences  # noqa: E402
from transformers import (  # noqa: E402
    AdamWeightDecay,
    AutoTokenizer,
    TFAutoModelForSeq2SeqLM,
)
from transformers import logging as hf_logging  # noqa: E402

hf_logging.set_verbosity_error()
np.random.seed(1234)
tf.random.set_seed(1234)
random.seed(1234)

# One "[ ... ]" annotation. The capturing group makes re.split keep the
# annotations in its output so they can be expanded in place.
_ANNOTATION_RE = re.compile(r'(\[.*?\])')
# Whitespace that is followed by a "]" before any "[" -- i.e. whitespace
# inside an annotation. It is protected so textwrap never breaks there.
_INSIDE_ANNOTATION_SPACE_RE = re.compile(r'(\s)+(?=[^[]*?\])')
# Sentinel for train_model: "custom_model not passed, read it from the CLI".
_FROM_CLI = object()


def create_arg_parser():
    '''Define and parse the command line arguments.

    Returns:
        argparse.Namespace holding the parsed options.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("-tf", "--transformer", default="google/byt5-small",
                        type=str,
                        help="this argument takes the name of a pretrained "
                             "language model on HuggingFace (e.g. "
                             "google/byt5-small); default is ByT5-small, "
                             "please visit HuggingFace for the full name")
    parser.add_argument("-c_model", "--custom_model", type=str,
                        help="this argument takes a custom "
                             "pretrained checkpoint")
    parser.add_argument("-train", "--train_data",
                        default='training_data10k.txt', type=str,
                        help="this argument takes the train "
                             "data file as input")
    parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
                        type=str,
                        help="this argument takes the dev data file "
                             "as input")
    parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
                        help="Set a custom learn rate for "
                             "the model, default is 5e-5")
    parser.add_argument("-bs", "--batch_size", default=8, type=int,
                        help="Set a custom batch size for "
                             "the pretrained language model, default is 8")
    parser.add_argument("-sl_train", "--sequence_length_train", default=155,
                        type=int,
                        help="Set a custom maximum sequence length "
                             "for the training data, "
                             "default is 155")
    parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
                        type=int,
                        help="Set a custom maximum sequence length "
                             "for the dev data, "
                             "default is 155")
    parser.add_argument("-ep", "--epochs", default=1, type=int,
                        help="This argument selects the amount of epochs "
                             "to run the model with, default is 1 epoch")
    parser.add_argument("-es", "--early_stop", default="val_loss", type=str,
                        help="Set the value to monitor for earlystopping")
    parser.add_argument("-es_p", "--early_stop_patience", default=2, type=int,
                        help="Set the patience value for "
                             "earlystopping, default is 2")
    return parser.parse_args()


def read_data(data_file):
    '''Read a UTF-8 text file and return its lines (newlines included).'''
    with open(data_file, encoding='utf-8') as file:
        return file.readlines()


def _parse_annotation(bracket):
    '''Return the (source, target) strings for one "[ @tag ... ]" annotation.

    The tokens are split on whitespace; token 0 is expected to be "[", token 1
    the tag and the last token "]".

    Raises:
        ValueError: if the bracketed text is not one of the supported tags.
    '''
    parts = bracket.split()
    tag = parts[1] if len(parts) > 1 else None
    if tag == '@alt':
        target, source = ''.join(parts[2:3]), ''.join(parts[3:-1])
    elif tag == '@mwu_alt':
        target = ''.join(parts[2:3])
        source = ''.join(parts[3:-1]).replace('-', '')
    elif tag == '@mwu':
        target, source = ''.join(parts[2:-1]), ' '.join(parts[2:-1])
    elif tag == '@postag':
        target = source = parts[-2]
    elif tag == '@phantom':
        # The target contains a word the source lacks.
        target, source = parts[2], ''
    else:
        raise ValueError('Unsupported annotation: {!r}'.format(bracket))
    # Targets written as "~word..." keep only the text after the first "~".
    if target.startswith('~'):
        target = target.split('~')[1]
    return source, target


def create_data(data):
    '''Split Alpino-format sentences into parallel source/target sentences.

    Every "[ @tag ... ]" annotation is replaced by its source form in the
    source sentence and by its target form in the target sentence; all other
    words are copied to both. Whitespace is normalised to single spaces, so
    an empty replacement (e.g. @phantom in the source) leaves no gap.

    Args:
        data: iterable of sentence strings.

    Returns:
        (source_text, target_text): two lists of strings of equal length.

    Raises:
        ValueError: on an unsupported bracketed annotation.
    '''
    source_text = []
    target_text = []
    for line in data:
        src_parts = []
        trg_parts = []
        # Odd indices are annotations, even indices the text between them.
        for index, piece in enumerate(_ANNOTATION_RE.split(line)):
            if index % 2:
                source, target = _parse_annotation(piece)
                # Pad so the expansion stays a separate token even when the
                # annotation is glued to neighbouring text.
                src_parts.append(' {} '.format(source))
                trg_parts.append(' {} '.format(target))
            else:
                src_parts.append(piece)
                trg_parts.append(piece)
        source_text.append(' '.join(''.join(src_parts).split()))
        target_text.append(' '.join(''.join(trg_parts).split()))
    return source_text, target_text


def split_sent(data, max_length):
    '''Take the sentence field of each "key|sentence" line, wrapping long ones.

    Sentences of at most max_length characters are kept whole. Longer ones
    are wrapped into chunks of roughly max_length characters without breaking
    inside "[ ... ]" annotations; chunks of two or fewer words are dropped.
    The wrapped chunks come first in the result, followed by the short
    sentences. Blank lines are skipped.

    NOTE(review): literal "$$" or "]##[" in the input is altered by the
    protect/restore round trip used for long sentences.
    '''
    short_sent = []
    long_sent = []
    for line in data:
        if not line.strip():
            continue
        sentence = line.split('|')[1]
        if len(sentence) <= max_length:
            short_sent.append(sentence)
            continue
        # Temporarily glue annotations together so they wrap as one word.
        protected = _INSIDE_ANNOTATION_SPACE_RE.sub('$$', sentence)
        protected = protected.replace("] [", "]##[")
        # break_on_hyphens=False: hyphens may occur inside annotations
        # (e.g. @mwu_alt sources) and must not become wrap points.
        long_sent.append(textwrap.wrap(protected, max_length,
                                       break_long_words=False,
                                       break_on_hyphens=False))
    new_data = []
    for chunks in long_sent:
        for chunk in chunks:
            chunk = chunk.replace(']##[', '] [').replace('$$', ' ')
            if len(chunk.split()) > 2:
                new_data.append(chunk)
    new_data.extend(short_sent)
    return new_data


def preprocess_function(tk, s, t):
    '''Tokenize source sentences as inputs and target sentences as labels.

    Returns the tokenizer output for s, with the target token ids added under
    "labels" and the target attention mask under "decoder_attention_mask".
    '''
    model_inputs = tk(s)
    with tk.as_target_tokenizer():
        labels = tk(t)
    model_inputs["labels"] = labels["input_ids"]
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs


def convert_tok(tok, sl):
    '''Post-pad every tokenized field to length sl and convert it to a tensor.

    Padding uses pad_sequences' default value 0, which custom_loss and
    accuracy treat as padding. NOTE(review): pad_sequences truncates from the
    front by default (truncating='pre') for sequences longer than sl.
    '''
    keys = ('input_ids', 'attention_mask', 'labels', 'decoder_attention_mask')
    return {key: tf.constant(pad_sequences(list(tok[key]), padding='post',
                                           maxlen=sl))
            for key in keys}


def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train,
                dev, custom_model=_FROM_CLI):
    '''Fine-tune a seq2seq model on the given data and save its weights.

    Args:
        model_name: HuggingFace name used for the tokenizer and, unless a
            custom checkpoint is given, for the model.
        lr, bs, ep: learning rate, batch size and number of epochs.
        sl_train, sl_dev: padded sequence lengths for train and dev data.
        es, es_p: metric monitored by early stopping and its patience.
        train, dev: lists of annotated sentences (see create_data).
        custom_model: optional PyTorch checkpoint to load instead of
            model_name. When omitted, it is read from --custom_model on the
            command line.

    Returns:
        The fine-tuned model. Its weights are also written to
        "<last part of model_name>_weights.h5".
    '''
    print('Training model: {}\nWith parameters:\nLearn rate: {}, '
          'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
          'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
    tk = AutoTokenizer.from_pretrained(model_name)
    if custom_model is _FROM_CLI:
        custom_model = create_arg_parser().custom_model
    source_train, target_train = create_data(train)
    source_test, target_test = create_data(dev)
    if custom_model:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(custom_model,
                                                        from_pt=True)
    else:
        model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
    train_tok = preprocess_function(tk, source_train, target_train)
    dev_tok = preprocess_function(tk, source_test, target_test)
    tf_train = convert_tok(train_tok, sl_train)
    tf_dev = convert_tok(dev_tok, sl_dev)
    optim = AdamWeightDecay(learning_rate=lr)
    model.compile(optimizer=optim, loss=custom_loss, metrics=[accuracy])
    ear_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p,
                                                restore_best_weights=True,
                                                mode="auto")
    model.fit(tf_train, validation_data=tf_dev, epochs=ep,
              batch_size=bs, callbacks=[ear_stop])
    # "google/byt5-small" -> "byt5-small_weights.h5"; works for any org/name.
    model.save_weights('{}_weights.h5'.format(model_name.split('/')[-1]))
    return model


def custom_loss(y_true, y_pred):
    '''Cross-entropy on logits, averaged over non-padding label positions.

    Label id 0 is treated as padding (convert_tok pads with 0). Returns 0
    instead of NaN when a batch contains no non-padding labels.
    '''
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))


def accuracy(y_true, y_pred):
    '''Token accuracy over non-padding label positions (label id != 0).'''
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)
    mask = tf.cast(y_true != 0, tf.float32)
    # Only count matches at real tokens; padded positions where both label
    # and prediction are 0 must not inflate the score.
    match = tf.cast(y_true == y_pred, tf.float32) * mask
    return tf.math.divide_no_nan(tf.reduce_sum(match), tf.reduce_sum(mask))


def main():
    '''Parse the command line, prepare the data and fine-tune the model.'''
    args = create_arg_parser()
    sl_train = args.sequence_length_train
    sl_dev = args.sequence_length_dev
    # Sentences are split 5 characters below the sequence length.
    # NOTE(review): presumably headroom for special tokens -- confirm.
    train_data = split_sent(read_data(args.train_data), sl_train - 5)
    dev_data = split_sent(read_data(args.dev_data), sl_dev - 5)
    print('Train size: {}\nDev size: {}\n'.format(len(train_data),
                                                 len(dev_data)))
    # The model name is passed through as given, so any name accepted by
    # from_pretrained can be used.
    print(train_model(args.transformer, args.learn_rate, args.batch_size,
                      sl_train, sl_dev, args.epochs, args.early_stop,
                      args.early_stop_patience, train_data, dev_data,
                      custom_model=args.custom_model))


if __name__ == '__main__':
    main()