|
import os |
|
from pathlib import Path |
|
import yaml |
|
import torch |
|
import optuna |
|
import pytorch_lightning as pl |
|
import click |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
|
|
from models.mobilevit import MobileVIT |
|
from data.data_preprocessing import FluorescentNeuronalDataModule |
|
|
|
|
|
MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small" |
|
|
|
|
|
if torch.backends.mps.is_available(): |
|
DEVICE = torch.device("mps:0") |
|
ACCELERATOR = "mps" |
|
elif torch.cuda.is_available(): |
|
DEVICE = torch.device("cuda") |
|
ACCELERATOR = "gpu" |
|
else: |
|
DEVICE = torch.device("cpu") |
|
ACCELERATOR = "cpu" |
|
|
|
RAW_DATA_PATH = "./data/raw/" |
|
DEFAULT_CONFIG_FILE = "./config/fluorescent_mobilevit_hps.yaml" |
|
|
|
CLASSES = {0: "Background", 1: "Neuron"} |
|
|
|
IMG_SIZE = [256, 256] |
|
|
|
|
|
@click.command() |
|
@click.option( |
|
"--data_dir", |
|
type=click.Path(exists=True, file_okay=True, path_type=Path), |
|
default=RAW_DATA_PATH, |
|
) |
|
@click.option( |
|
"--config_file", |
|
type=click.Path(exists=True, file_okay=True, path_type=Path), |
|
default=DEFAULT_CONFIG_FILE, |
|
) |
|
@click.option("--dataset_size", type=click.FLOAT, default=0.25) |
|
@click.option("--force-tune/--no-force-tune", default=False) |
|
def get_best_params(data_dir, config_file, dataset_size, force_tune) -> dict: |
|
def objective(trial: optuna.Trial, dataset_size=dataset_size) -> float: |
|
|
|
learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True) |
|
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True) |
|
batch_size = trial.suggest_int("batch_size", 2, 4, log=True) |
|
|
|
|
|
early_stopping_cb = EarlyStopping(monitor="val_loss", patience=2) |
|
|
|
|
|
model = MobileVIT(learning_rate=learning_rate, weight_decay=weight_decay) |
|
|
|
|
|
data_module = FluorescentNeuronalDataModule( |
|
batch_size=batch_size, dataset_size=dataset_size, data_dir=data_dir |
|
) |
|
data_module.setup() |
|
|
|
|
|
trainer = pl.Trainer( |
|
devices=1, |
|
accelerator=ACCELERATOR, |
|
precision="16-mixed", |
|
max_epochs=5, |
|
log_every_n_steps=5, |
|
callbacks=[early_stopping_cb], |
|
) |
|
trainer.fit( |
|
model, |
|
train_dataloaders=data_module.train_dataloader(), |
|
val_dataloaders=data_module.val_dataloader(), |
|
) |
|
return trainer.callback_metrics["val_loss"].item() |
|
|
|
if os.path.exists(config_file) and force_tune: |
|
os.remove(config_file) |
|
pruner = optuna.pruners.MedianPruner() |
|
study = optuna.create_study(direction="maximize", pruner=pruner) |
|
|
|
study.optimize(objective, n_trials=25) |
|
best_params = study.best_params |
|
with open(config_file, "w") as file: |
|
yaml.dump(best_params, file) |
|
elif os.path.exists(config_file): |
|
with open(config_file, "r") as file: |
|
best_params = yaml.safe_load(file) |
|
else: |
|
pruner = optuna.pruners.MedianPruner() |
|
study = optuna.create_study(direction="minimize", pruner=pruner) |
|
|
|
study.optimize(objective, n_trials=25) |
|
best_params = study.best_params |
|
with open(config_file, "w") as file: |
|
yaml.dump(best_params, file) |
|
|
|
click.echo(f"The best parameters are:\n{best_params}") |
|
|