icefire080 commited on
Commit
502cc48
1 Parent(s): 5a43832

a tiny bug fix missing default_training_args

Browse files

hello , there is a tiny bug in function "get_default_train_args" when num_layers != 6. A default_training_args which is a empy dict is missing.

Files changed (1) hide show
  1. geneformer/classifier_utils.py +2 -0
geneformer/classifier_utils.py CHANGED
@@ -387,6 +387,8 @@ def get_default_train_args(model, classifier, data, output_dir):
387
  "per_device_train_batch_size": batch_size,
388
  "per_device_eval_batch_size": batch_size,
389
  }
 
 
390
 
391
  training_args = {
392
  "num_train_epochs": epochs,
 
387
  "per_device_train_batch_size": batch_size,
388
  "per_device_eval_batch_size": batch_size,
389
  }
390
+ else:
391
+ default_training_args = {}
392
 
393
  training_args = {
394
  "num_train_epochs": epochs,