File size: 5,453 Bytes
47990ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import random

from .data import get_data_loader, preload_and_process_data
from .imports import *
from .model import GeneformerMultiTask
from .train import objective, train_model
from .utils import save_model


def set_seed(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU and on all CUDA devices with the same value. It then puts cuDNN into
    deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN benchmarking picks kernels by timing them, which can vary between
    # runs; disable it and request deterministic algorithms instead.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def run_manual_tuning(config):
    """Train and save a single model using hand-picked hyperparameters.

    The entries of ``config["manual_hyperparameters"]`` are printed and then
    copied into ``config`` itself, which modifies the caller's dict in place,
    before ``train_model`` runs. The trained model is written with
    ``save_model`` to ``<model_save_path>/GeneformerMultiTask``. A
    ``hyperparameters.json`` file is written next to it. That file holds the
    manual hyperparameters plus the dropout, task-weight, layer-freezing and
    attention-pooling settings read back from ``config``.

    Args:
        config: Run configuration dict. It must provide ``seed``,
            ``batch_size``, ``manual_hyperparameters`` and
            ``model_save_path``. After the manual overrides are applied, it
            must also provide ``dropout_rate``, ``use_task_weights``,
            ``task_weights``, ``max_layers_to_freeze`` and
            ``use_attention_pooling``. The dict is also passed through to
            the data and training helpers.

    Returns:
        The validation loss returned by ``train_model``.
    """
    set_seed(config["seed"])  # reproducibility

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    processed = preload_and_process_data(config)
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = processed

    # NOTE(review): the loaders are built before the manual overrides below
    # are applied, so a "batch_size" entry in manual_hyperparameters would not
    # change them. Confirm whether that is intended.
    batch_size = config["batch_size"]
    train_loader = get_data_loader(train_dataset, batch_size)
    val_loader = get_data_loader(val_dataset, batch_size)

    manual_hparams = config["manual_hyperparameters"]
    print("\nManual hyperparameters being used:")
    for name, setting in manual_hparams.items():
        print(f"{name}: {setting}")
    print()  # blank line separates the listing from the training output

    # Layer the manual values over the existing config (in place) so that
    # train_model sees them.
    for name, setting in manual_hparams.items():
        config[name] = setting

    val_loss, trained_model = train_model(
        config,
        device,
        train_loader,
        val_loader,
        train_cell_id_mapping,
        val_cell_id_mapping,
        num_labels_list,
    )
    print(f"\nValidation loss with manual hyperparameters: {val_loss}")

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(trained_model, model_save_directory)

    # Start from the manual values, then add or overwrite these settings with
    # their final values from config (key order matches a {**manual, ...}
    # literal).
    extra_settings = (
        "dropout_rate",
        "use_task_weights",
        "task_weights",
        "max_layers_to_freeze",
        "use_attention_pooling",
    )
    hyperparams_to_save = dict(config["manual_hyperparameters"])
    for name in extra_settings:
        hyperparams_to_save[name] = config[name]

    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w") as handle:
        json.dump(hyperparams_to_save, handle)
    print(f"Manual hyperparameters saved to {hyperparams_path}")

    return val_loss


def run_optuna_study(config):
    """Train once with fixed hyperparameters, or run an Optuna search.

    If ``config["use_manual_hyperparameters"]`` is truthy, ``train_model`` is
    called once with ``config`` as-is. Its result is discarded and nothing is
    saved here.

    Otherwise an Optuna study that minimizes validation loss runs for
    ``config["n_trials"]`` trials of ``objective``. The best trial's weights
    are read from its ``user_attrs["model_state_dict"]``, loaded into a fresh
    ``GeneformerMultiTask``, and saved under
    ``<model_save_path>/GeneformerMultiTask``. The best hyperparameters and
    task weights are written to ``hyperparameters.json`` in the same folder.

    Args:
        config: Run configuration dict. Keys read here are ``seed``,
            ``batch_size``, ``use_manual_hyperparameters``, ``study_name``,
            ``n_trials``, ``pretrained_path``, ``use_task_weights`` and
            ``model_save_path``. The dict is also passed through to the data
            and training helpers.
    """
    # Set seed for reproducibility
    set_seed(config["seed"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    ) = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])

    if config["use_manual_hyperparameters"]:
        # NOTE(review): the trained model and loss are discarded on this path
        # and nothing is saved. run_manual_tuning is the variant that applies
        # config["manual_hyperparameters"] and persists the result.
        train_model(
            config,
            device,
            train_loader,
            val_loader,
            train_cell_id_mapping,
            val_cell_id_mapping,
            num_labels_list,
        )
        return

    # Bind everything except the trial, so Optuna can call objective(trial).
    objective_with_config_and_data = functools.partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cell_id_mapping=train_cell_id_mapping,
        val_cell_id_mapping=val_cell_id_mapping,
        num_labels_list=num_labels_list,
        config=config,
        device=device,
    )

    study = optuna.create_study(
        direction="minimize",  # Minimize validation loss
        study_name=config["study_name"],
        # With no storage the study is in-memory and always new, so
        # load_if_exists only matters if storage is re-enabled.
        # storage=config["storage"],
        load_if_exists=True,
    )
    study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])

    # Query the best trial once instead of re-reading study.best_trial.
    best_trial = study.best_trial
    best_params = best_trial.params
    best_task_weights = best_trial.user_attrs["task_weights"]
    print("Saving the best model and its hyperparameters...")

    # NOTE(review): only the dropout and task-weight settings are passed here.
    # If the search also varied architecture options (e.g. attention pooling),
    # the rebuilt model may not match the saved weights. The key-mismatch
    # report below will surface that.
    best_model = GeneformerMultiTask(
        config["pretrained_path"],
        num_labels_list,
        dropout_rate=best_params["dropout_rate"],
        use_task_weights=config["use_task_weights"],
        task_weights=best_task_weights,
    )

    # Wrapped models (e.g. DataParallel/DDP) prefix parameter names with
    # "module.". Strip it only as a leading prefix; a blanket str.replace
    # would also mangle names that merely contain it (e.g.
    # "submodule.weight" -> "subweight").
    prefix = "module."
    best_model_state_dict = {}
    for key, value in best_trial.user_attrs["model_state_dict"].items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        best_model_state_dict[key] = value

    # strict=False tolerates mismatched keys. Report them rather than
    # silently dropping weights.
    load_result = best_model.load_state_dict(best_model_state_dict, strict=False)
    if load_result.missing_keys:
        print(
            "Warning: model keys not found in the best trial's weights "
            f"(not loaded): {load_result.missing_keys}"
        )
    if load_result.unexpected_keys:
        print(
            "Warning: keys in the best trial's weights not used by the model "
            f"(ignored): {load_result.unexpected_keys}"
        )

    model_save_directory = os.path.join(
        config["model_save_path"], "GeneformerMultiTask"
    )
    save_model(best_model, model_save_directory)

    # Additionally, save the best hyperparameters and task weights
    hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
    with open(hyperparams_path, "w", encoding="utf-8") as f:
        json.dump({**best_params, "task_weights": best_task_weights}, f)
    print(f"Best hyperparameters and task weights saved to {hyperparams_path}")