File size: 343 Bytes
02ba63a
 
 
 
9d42859
02ba63a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
import pytorch_lightning as pl
from model import Autoencoder

class AutoencoderModule(pl.LightningModule):
    """Lightning wrapper around ``model.Autoencoder``.

    The wrapped network is built once in ``__init__`` and stored as
    ``self.model``. ``forward`` simply delegates to it.

    NOTE(review): no ``training_step`` / ``configure_optimizers`` is defined
    here. ``Trainer.fit`` needs them, so confirm they are provided elsewhere
    (e.g. by a subclass) before training with this module.

    Args:
        feature_dim: Passed positionally to ``Autoencoder``. Presumably the
            size of the learned feature/latent space; confirm against
            ``model.Autoencoder``.
    """

    def __init__(self, feature_dim=64):
        super().__init__()
        # Record ``feature_dim`` in ``self.hparams`` and in saved checkpoints so
        # that ``load_from_checkpoint`` rebuilds the network with the same size
        # instead of silently falling back to the default value.
        self.save_hyperparameters()
        self.feature_dim = feature_dim
        self.model = Autoencoder(self.feature_dim)

    def forward(self, x):
        """Run ``x`` through the wrapped autoencoder and return its output."""
        return self.model(x)