# swim / main.py
# Author: qninhdt — last commit 020afa7 ("cc")
import argparse

import torch
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import OmegaConf
from torchinfo import summary

from swim.modules.dataset import SwimDataModule
from swim.utils import instantiate_from_config
# Defaults reproduce the values this script originally hard-coded.
DEFAULT_CONFIG = "configs/autoencoder/autoencoder_kl_32x32x4.yaml"
DEFAULT_DATA_ROOT = "/cm/shared/ninhnq3/datasets/swim_data"


def parse_args(argv=None):
    """Parse command-line options for autoencoder training.

    Every default equals the value the script previously hard-coded, so
    running ``python main.py`` with no arguments behaves exactly as before.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` holding the training options.
    """
    parser = argparse.ArgumentParser(description="Train the SWIM KL autoencoder.")
    parser.add_argument(
        "--config", default=DEFAULT_CONFIG, help="Path to the OmegaConf YAML config."
    )
    parser.add_argument(
        "--data-root", default=DEFAULT_DATA_ROOT, help="Root directory of the dataset."
    )
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--img-size", type=int, default=512)
    parser.add_argument("--max-epochs", type=int, default=10)
    parser.add_argument(
        "--devices",
        type=int,
        nargs="+",
        default=[0],
        help="Device indices passed to the Lightning Trainer.",
    )
    parser.add_argument("--log-every-n-steps", type=int, default=10)
    parser.add_argument("--project", default="swim", help="W&B project name.")
    parser.add_argument("--run-name", default="autoencoder_kl", help="W&B run name.")
    return parser.parse_args(argv)


def main(argv=None):
    """Build the model, datamodule, logger and trainer, then run training.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    # Permit lower-precision internal float32 matmuls for speed. This is a
    # process-global setting, so it is applied only when training actually runs.
    torch.set_float32_matmul_precision("medium")

    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    # The config's base learning rate is used as-is (not scaled by batch size
    # or device count). NOTE(review): presumably the model's optimizer setup
    # reads ``self.learning_rate`` — confirm.
    model.learning_rate = config.model.base_learning_rate

    datamodule = SwimDataModule(
        root_dir=args.data_root, batch_size=args.batch_size, img_size=args.img_size
    )
    logger = WandbLogger(project=args.project, name=args.run_name)
    trainer = Trainer(
        max_epochs=args.max_epochs,
        devices=args.devices,
        logger=logger,
        log_every_n_steps=args.log_every_n_steps,
    )
    trainer.fit(model, datamodule)


# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()