import pytorch_lightning as pl

from model import Autoencoder


class AutoencoderModule(pl.LightningModule):
    """LightningModule wrapper around the project's ``Autoencoder`` network.

    Args:
        feature_dim: Passed straight through to ``Autoencoder`` as its only
            constructor argument (default 64). NOTE(review): presumably the
            latent / bottleneck size -- confirm against ``model.Autoencoder``.
    """

    def __init__(self, feature_dim: int = 64):
        super().__init__()
        # Record ``feature_dim`` in ``self.hparams`` (and therefore in saved
        # checkpoints) so ``AutoencoderModule.load_from_checkpoint(path)``
        # rebuilds the network with the dimension it was trained with
        # instead of silently falling back to the default.
        self.save_hyperparameters()
        self.feature_dim = feature_dim
        self.model = Autoencoder(self.feature_dim)

    def forward(self, x):
        """Run ``x`` through the wrapped ``Autoencoder``.

        Returns whatever ``self.model(x)`` returns, unchanged.
        """
        return self.model(x)