# autoencoder1d-AT-v1 / autoencoder.py
# Commit 57b6cde by flavioschneider: "Upload AutoEncoder1d"
from transformers import PreTrainedModel
from audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck
from .autoencoder_config import AutoEncoder1dConfig
# Registry of supported bottlenecks, keyed by the name stored in
# ``AutoEncoder1dConfig.bottleneck``; values are classes, instantiated
# with no arguments when the model is built.
bottleneck = dict(tanh=TanhBottleneck)
class AutoEncoder1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``audio_encoders_pytorch.AutoEncoder1d``.

    Every hyperparameter of the wrapped autoencoder is read from an
    ``AutoEncoder1dConfig``. ``forward``, ``encode`` and ``decode`` delegate,
    unchanged, to the wrapped model stored in ``self.autoencoder``.
    """

    # Lets ``from_pretrained`` / ``save_pretrained`` know which config to use.
    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        """Build the wrapped autoencoder from ``config``.

        Args:
            config: Model configuration. ``config.bottleneck`` must be a key of
                the module-level ``bottleneck`` registry (e.g. ``'tanh'``).

        Raises:
            KeyError: If ``config.bottleneck`` is not a registered bottleneck
                name. The message lists the accepted names.
        """
        super().__init__(config)
        try:
            bottleneck_cls = bottleneck[config.bottleneck]
        except KeyError:
            # Same exception type as the bare dict lookup, but with a message
            # that tells the user which names are accepted.
            raise KeyError(
                f"Unknown bottleneck {config.bottleneck!r}; "
                f"expected one of {sorted(bottleneck)}"
            ) from None
        self.autoencoder = AE1d(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            bottleneck=bottleneck_cls(),
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped autoencoder's forward pass.

        All arguments are passed through unchanged. Returns whatever the
        wrapped ``AE1d`` instance returns when called.
        """
        return self.autoencoder(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's ``encode`` and return its result."""
        return self.autoencoder.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the wrapped autoencoder's ``decode`` and return its result."""
        return self.autoencoder.decode(*args, **kwargs)