update
T5Trainer.py  +16 -18  CHANGED
```diff
@@ -27,33 +27,31 @@ def create_arg_parser():
     parser.add_argument("-tf", "--transformer", default="google/byt5-small",
                         type=str, help="this argument takes the pretrained "
                                        "language model URL from HuggingFace "
-                                       "default is
+                                       "default is ByT5-small, please visit "
                                        "HuggingFace for full URL")
     parser.add_argument("-c_model", "--custom_model",
                         type=str, help="this argument takes a custom "
                                        "pretrained checkpoint")
-    parser.add_argument("-train", "--train_data", default='
+    parser.add_argument("-train", "--train_data", default='training_data10k.txt',
                         type=str, help="this argument takes the train "
                                        "data file as input")
-    parser.add_argument("-dev", "--dev_data", default='
-                        help="this argument takes the dev data file
-
-    parser.add_argument("-
-                        help="class weights for custom loss calculation")
-    parser.add_argument("-lr", "--learn_rate", default=1e-3, type=float,
+    parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
+                        type=str, help="this argument takes the dev data file "
+                                       "as input")
+    parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
                         help="Set a custom learn rate for "
-                             "the
-    parser.add_argument("-bs", "--batch_size", default=
+                             "the model, default is 5e-5")
+    parser.add_argument("-bs", "--batch_size", default=8, type=int,
                         help="Set a custom batch size for "
                              "the pretrained language model, default is 8")
-    parser.add_argument("-sl_train", "--sequence_length_train", default=
+    parser.add_argument("-sl_train", "--sequence_length_train", default=155,
                         type=int, help="Set a custom maximum sequence length"
                                        "for the pretrained language model,"
-                                       "default is
-    parser.add_argument("-sl_dev", "--sequence_length_dev", default=
+                                       "default is 155")
+    parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
                         type=int, help="Set a custom maximum sequence length"
                                        "for the pretrained language model,"
-                                       "default is
+                                       "default is 155")
     parser.add_argument("-ep", "--epochs", default=1, type=int,
                         help="This argument selects the amount of epochs "
                              "to run the model with, default is 1 epoch")
@@ -61,7 +59,7 @@ def create_arg_parser():
                         help="Set the value to monitor for earlystopping")
     parser.add_argument("-es_p", "--early_stop_patience", default=2,
                         type=int, help="Set the patience value for "
-                        "earlystopping")
+                        "earlystopping, default is 2")
     args = parser.parse_args()
     return args
 
@@ -131,7 +129,7 @@ def create_data(data):
 
 
 def split_sent(data, max_length):
-    '''Splitting sentences if longer than given
+    '''Splitting sentences if longer than given max_length value'''
     short_sent = []
     long_sent = []
     for n in data:
@@ -159,7 +157,7 @@ def split_sent(data, max_length):
 
 
 def preprocess_function(tk, s, t):
-    '''tokenizing
+    '''tokenizing text and labels'''
     model_inputs = tk(s)
 
     with tk.as_target_tokenizer():
@@ -195,7 +193,7 @@ def convert_tok(tok, sl):
 
 
 def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
-    '''Finetune and save a given T5 version with given
+    '''Finetune and save a given T5 version with given parameters'''
     print('Training model: {}\nWith parameters:\nLearn rate: {}, '
           'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
           'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
```