import torch
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf
from torchinfo import summary

from swim.modules.dataset import SwimDataModule
from swim.utils import instantiate_from_config

# Trade a little float32 matmul precision for speed on Tensor Core GPUs.
torch.set_float32_matmul_precision("medium")

# Build the KL autoencoder from the config and set its learning rate.
config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")

model = instantiate_from_config(config.model)
model.learning_rate = config.model.base_learning_rate
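
# Optionally print a layer summary before training. The input size here is an
# assumption (3-channel images at the batch_size=2 / img_size=512 used by the
# datamodule below); adjust it if the model's forward expects different inputs.
summary(model, input_size=(2, 3, 512, 512), depth=3)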

datamodule = SwimDataModule(
    root_dir="/cm/shared/ninhnq3/datasets/swim_data", batch_size=2, img_size=512
)

# Log training metrics to Weights & Biases.
logger = WandbLogger(project="swim", name="autoencoder_kl")

# Train for 10 epochs on a single GPU.
trainer = Trainer(max_epochs=10, devices=[0], logger=logger, log_every_n_steps=10)
trainer.fit(model, datamodule=datamodule)