TripletGeoEncoder-demo / model_module.py
yeq6x's picture
refactoring
9d42859
raw
history blame contribute delete
343 Bytes
import pytorch_lightning as pl
from model import Autoencoder
class AutoencoderModule(pl.LightningModule):
    """Lightning wrapper around the project's ``Autoencoder`` network.

    The wrapped network is stored as ``self.model``. ``forward`` delegates
    to it unchanged, adding no pre- or post-processing.
    """

    def __init__(self, feature_dim: int = 64):
        """Build the wrapped autoencoder.

        Args:
            feature_dim: Passed positionally to ``Autoencoder``. Presumably
                the size of the learned feature/latent vector; confirm
                against ``model.Autoencoder``.
        """
        super().__init__()
        # Record the init args in ``self.hparams`` (and thus in saved
        # checkpoints). ``load_from_checkpoint`` can then rebuild the network
        # with the same ``feature_dim`` instead of the default 64 unless the
        # caller passes it again.
        self.save_hyperparameters()
        self.feature_dim = feature_dim
        self.model = Autoencoder(self.feature_dim)

    def forward(self, x):
        """Run ``x`` through the wrapped autoencoder and return its output.

        The input and output format is whatever ``Autoencoder`` uses; this
        wrapper adds no processing of its own.
        """
        return self.model(x)