AWolters committed on
Commit
c44944d
1 Parent(s): 4403157
Files changed (1)
  1. T5Trainer.py +16 -18
T5Trainer.py CHANGED
@@ -27,33 +27,31 @@ def create_arg_parser():
     parser.add_argument("-tf", "--transformer", default="google/byt5-small",
                         type=str, help="this argument takes the pretrained "
                                        "language model URL from HuggingFace "
-                                       "default is HateBERT, please visit "
+                                       "default is ByT5-small, please visit "
                                        "HuggingFace for full URL")
     parser.add_argument("-c_model", "--custom_model",
                         type=str, help="this argument takes a custom "
                                        "pretrained checkpoint")
-    parser.add_argument("-train", "--train_data", default='training_data.txt',
+    parser.add_argument("-train", "--train_data", default='training_data10k.txt',
                         type=str, help="this argument takes the train "
                                        "data file as input")
-    parser.add_argument("-dev", "--dev_data", default='dev_data.txt', type=str,
-                        help="this argument takes the dev data file as "
-                             "input")
-    parser.add_argument("-sample_weight", "--sample_weight", type=str,
-                        help="class weights for custom loss calculation")
-    parser.add_argument("-lr", "--learn_rate", default=1e-3, type=float,
+    parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
+                        type=str, help="this argument takes the dev data file "
+                                       "as input")
+    parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
                         help="Set a custom learn rate for "
-                             "the pretrained language model, default is 5e-5")
-    parser.add_argument("-bs", "--batch_size", default=16, type=int,
+                             "the model, default is 5e-5")
+    parser.add_argument("-bs", "--batch_size", default=8, type=int,
                         help="Set a custom batch size for "
                              "the pretrained language model, default is 8")
-    parser.add_argument("-sl_train", "--sequence_length_train", default=100,
+    parser.add_argument("-sl_train", "--sequence_length_train", default=155,
                         type=int, help="Set a custom maximum sequence length"
                                        "for the pretrained language model,"
-                                       "default is 100")
-    parser.add_argument("-sl_dev", "--sequence_length_dev", default=100,
+                                       "default is 155")
+    parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
                         type=int, help="Set a custom maximum sequence length"
                                        "for the pretrained language model,"
-                                       "default is 100")
+                                       "default is 155")
     parser.add_argument("-ep", "--epochs", default=1, type=int,
                         help="This argument selects the amount of epochs "
                              "to run the model with, default is 1 epoch")
@@ -61,7 +59,7 @@ def create_arg_parser():
                         help="Set the value to monitor for earlystopping")
     parser.add_argument("-es_p", "--early_stop_patience", default=2,
                         type=int, help="Set the patience value for "
-                                       "earlystopping")
+                                       "earlystopping, default is 2")
     args = parser.parse_args()
     return args
 
@@ -131,7 +129,7 @@ def create_data(data):
 
 
 def split_sent(data, max_length):
-    '''Splitting sentences if longer than given threshold'''
+    '''Splitting sentences if longer than given max_length value'''
     short_sent = []
     long_sent = []
     for n in data:
@@ -159,7 +157,7 @@ def split_sent(data, max_length):
 
 
 def preprocess_function(tk, s, t):
-    '''tokenizing data text and labels'''
+    '''tokenizing text and labels'''
     model_inputs = tk(s)
 
     with tk.as_target_tokenizer():
@@ -195,7 +193,7 @@ def convert_tok(tok, sl):
 
 
 def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
-    '''Finetune and save a given T5 version with given (hyper)parameters'''
+    '''Finetune and save a given T5 version with given parameters'''
     print('Training model: {}\nWith parameters:\nLearn rate: {}, '
           'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
          'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
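For context on the preprocess_function hunk: it follows the standard Hugging Face seq2seq tokenization flow, calling the tokenizer once on the source strings and once, inside the as_target_tokenizer() context, on the target strings. A minimal self-contained sketch of that pattern (the example strings are illustrative, not from the repo; note that ByT5 tokenizes raw UTF-8 bytes, so the sequence-length values above count bytes rather than wordpieces):

    from transformers import AutoTokenizer

    # ByT5 is byte-level, so no vocabulary beyond the 256 byte values.
    tk = AutoTokenizer.from_pretrained("google/byt5-small")

    # Illustrative source/target pairs; the repo reads these from the
    # train/dev files named in the argparse defaults.
    s = ["a sentense with an erorr"]
    t = ["a sentence with an error"]

    model_inputs = tk(s)             # byte-level ids for the sources
    with tk.as_target_tokenizer():   # switch the tokenizer to label mode
        labels = tk(t)
    model_inputs["labels"] = labels["input_ids"]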
 
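With this commit the defaults line up with the help texts (batch size 8, learn rate 5e-5, sequence length 155), so running python T5Trainer.py with no flags should now match the documented behaviour. On the early-stopping flags, -es selects the metric to monitor and -es_p the patience; assuming train_model() wires these into a Keras-style callback (the training loop itself is outside this diff), the hookup would look roughly like:

    import tensorflow as tf

    # Hypothetical wiring; es and es_p come from create_arg_parser().
    # es is a metric name such as "val_loss"; es_p is the number of
    # epochs without improvement before training stops.
    early_stop = tf.keras.callbacks.EarlyStopping(monitor=es, patience=es_p)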