madhavanvenkatesh
committed on
Fixed a bug related to dynamic ranges given as a dictionary, where the 'min' and 'max' values were mismatched in the Optuna suggest function
Browse files
- geneformer/mtl/train.py +9 -14
geneformer/mtl/train.py
CHANGED
@@ -9,7 +9,7 @@ from tqdm import tqdm
|
|
9 |
|
10 |
from .imports import *
|
11 |
from .model import GeneformerMultiTask
|
12 |
-
from .utils import calculate_task_specific_metrics
|
13 |
|
14 |
|
15 |
def set_seed(seed):
|
@@ -280,7 +280,7 @@ def objective(
|
|
280 |
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
281 |
)
|
282 |
config["use_attention_pooling"] = trial.suggest_categorical(
|
283 |
-
"use_attention_pooling", [
|
284 |
)
|
285 |
|
286 |
if config["use_task_weights"]:
|
@@ -299,18 +299,13 @@ def objective(
|
|
299 |
else:
|
300 |
config["task_weights"] = None
|
301 |
|
302 |
-
#
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
elif isinstance(config["max_layers_to_freeze"], int):
|
310 |
-
# If it's already an int, we don't need to suggest it
|
311 |
-
pass
|
312 |
-
else:
|
313 |
-
raise ValueError("Invalid type for max_layers_to_freeze. Expected dict or int.")
|
314 |
|
315 |
model = create_model(config, num_labels_list, device)
|
316 |
total_steps = len(train_loader) * config["epochs"]
|
|
|
9 |
|
10 |
from .imports import *
|
11 |
from .model import GeneformerMultiTask
|
12 |
+
from .utils import calculate_task_specific_metrics, get_layer_freeze_range
|
13 |
|
14 |
|
15 |
def set_seed(seed):
|
|
|
280 |
"lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
|
281 |
)
|
282 |
config["use_attention_pooling"] = trial.suggest_categorical(
|
283 |
+
"use_attention_pooling", [False]
|
284 |
)
|
285 |
|
286 |
if config["use_task_weights"]:
|
|
|
299 |
else:
|
300 |
config["task_weights"] = None
|
301 |
|
302 |
+
# Dynamic range for max_layers_to_freeze
|
303 |
+
freeze_range = get_layer_freeze_range(config["pretrained_path"])
|
304 |
+
config["max_layers_to_freeze"] = trial.suggest_int(
|
305 |
+
"max_layers_to_freeze",
|
306 |
+
freeze_range["min"],
|
307 |
+
freeze_range["max"]
|
308 |
+
)
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
model = create_model(config, num_labels_list, device)
|
311 |
total_steps = len(train_loader) * config["epochs"]
|