File size: 2,121 Bytes
c5c5181 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import yaml
from pathlib import Path
import click
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from models.mobilevit import MobileVIT
from data.data_preprocessing import FluorescentNeuronalDataModule
# Default filesystem locations. All are relative paths, so they resolve
# against the current working directory the script is launched from.

# YAML file with the tuned hyperparameters read by ``train_model``.
CONFIG_FILE: str = "config/fluorescent_mobilevit_hps.yaml"
# Dataset directory handed to the data module.
DATA_DIR: str = "data/raw/"
# Output directory for the TensorBoard logger.
LOGS_DIR: str = "reports/logs/FluorescentMobileVIT"
# Output directory for model checkpoints.
MODEL_DIR: str = "models/FluorescentMobileVIT"
# Pick the training hardware once, at import time. Apple's Metal backend is
# preferred, then CUDA, with the CPU as the fallback. ``ACCELERATOR`` is the
# name passed to Lightning's ``Trainer(accelerator=...)``; ``DEVICE`` is the
# corresponding ``torch.device``.
def _select_device():
    """Return ``(torch.device, accelerator_name)`` for the current host."""
    if torch.backends.mps.is_available():
        return torch.device("mps:0"), "mps"
    if torch.cuda.is_available():
        return torch.device("cuda"), "gpu"
    return torch.device("cpu"), "cpu"


DEVICE, ACCELERATOR = _select_device()
def _load_hyperparameters(config_file):
    """Load the tuned hyperparameters used to build the model and data module.

    Args:
        config_file: Path to a YAML file holding a mapping with at least the
            keys ``learning_rate``, ``weight_decay`` and ``batch_size``.

    Returns:
        dict: The parsed mapping. Extra keys are kept but not used here.

    Raises:
        click.ClickException: If the file does not hold a YAML mapping, or is
            missing a required key. A ClickException makes the CLI print a
            one-line error instead of a TypeError/KeyError traceback.
    """
    with open(config_file, "r", encoding="utf-8") as file:
        params = yaml.safe_load(file)
    # An empty file parses to None, and a scalar or list is not a mapping.
    # Either would otherwise fail later with an unhelpful subscript error.
    if not isinstance(params, dict):
        raise click.ClickException(
            f"{config_file} must contain a YAML mapping of hyperparameters."
        )
    required = ("learning_rate", "weight_decay", "batch_size")
    missing = [key for key in required if key not in params]
    if missing:
        raise click.ClickException(
            f"{config_file} is missing required hyperparameter(s): "
            + ", ".join(missing)
        )
    return params


@click.command()
@click.option(
    "--data_dir",
    # The dataset location is a directory (default "data/raw/"), so reject
    # plain files up front.
    type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
    default=DATA_DIR,
    show_default=True,
    help="Dataset directory passed to FluorescentNeuronalDataModule.",
)
@click.option(
    "--config_file",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
    default=CONFIG_FILE,
    show_default=True,
    help="YAML file with learning_rate, weight_decay and batch_size.",
)
def train_model(data_dir, config_file):
    """Train MobileVIT with tuned hyperparameters, then test the best checkpoint.

    \f
    Args:
        data_dir: Dataset directory forwarded to FluorescentNeuronalDataModule.
        config_file: YAML hyperparameter file. It must provide
            ``learning_rate``, ``weight_decay`` and ``batch_size``.

    Side effects:
        Writes the single best checkpoint (lowest ``val_loss``) to MODEL_DIR
        and TensorBoard logs to LOGS_DIR.
    """
    # Load the tuned hyperparameters.
    best_params = _load_hyperparameters(config_file)

    # Instantiate the model.
    model = MobileVIT(
        learning_rate=best_params["learning_rate"],
        weight_decay=best_params["weight_decay"],
    )

    # Keep only the checkpoint with the lowest validation loss.
    model_checkpoint_cb = ModelCheckpoint(
        save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss"
    )
    logger = TensorBoardLogger(save_dir=LOGS_DIR)

    # Create the trainer with its parameters.
    trainer = pl.Trainer(
        logger=logger,
        devices=1,
        accelerator=ACCELERATOR,
        # NOTE(review): 16-bit AMP is primarily a CUDA feature. On the "mps"
        # and "cpu" accelerators Lightning may warn or fall back, so confirm
        # the effective precision for the installed Lightning version.
        precision=16,
        max_epochs=100,
        log_every_n_steps=20,
        callbacks=[model_checkpoint_cb],
    )

    data_module = FluorescentNeuronalDataModule(
        data_dir=data_dir, batch_size=best_params["batch_size"]
    )
    trainer.fit(model=model, datamodule=data_module)
    # Evaluate the weights ModelCheckpoint judged best on val_loss, not the
    # in-memory weights left over from the final training epoch.
    trainer.test(model=model, datamodule=data_module, ckpt_path="best")
    click.echo("\n\n==========The Training has Finished!==========")


if __name__ == "__main__":
    train_model()
|