from transformers import PreTrainedModel
from audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck

from .autoencoder_config import AutoEncoder1dConfig

# Registry mapping the ``config.bottleneck`` string to a bottleneck class.
# The selected class is instantiated with no arguments when the model is built.
bottleneck = {
    'tanh': TanhBottleneck,
}


class AutoEncoder1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``audio_encoders_pytorch.AutoEncoder1d``.

    The wrapped autoencoder is built from the fields of an ``AutoEncoder1dConfig``
    and stored as ``self.autoencoder``. ``forward``, ``encode`` and ``decode``
    delegate to it unchanged.
    """

    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        """Build the wrapped autoencoder from ``config``.

        Args:
            config: Model configuration. Its ``in_channels``, ``patch_size``,
                ``channels``, ``multipliers``, ``factors`` and ``num_blocks``
                are passed straight to ``AE1d``. ``bottleneck`` selects an
                entry of the module-level ``bottleneck`` registry.

        Raises:
            KeyError: if ``config.bottleneck`` is not a key of the registry.
                The message lists the supported names.
        """
        super().__init__(config)
        # Look up the bottleneck class separately from instantiating it, so an
        # unsupported config value produces an actionable error instead of a
        # bare ``KeyError: '<name>'``.
        try:
            bottleneck_cls = bottleneck[config.bottleneck]
        except KeyError:
            raise KeyError(
                f"Unknown bottleneck {config.bottleneck!r}; "
                f"expected one of {sorted(bottleneck)}"
            ) from None
        self.autoencoder = AE1d(
            in_channels=config.in_channels,
            patch_size=config.patch_size,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            bottleneck=bottleneck_cls(),
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped autoencoder's forward pass.

        All arguments are forwarded unchanged. See ``AE1d.forward`` for the
        accepted inputs and return value (presumably waveform tensors; not
        visible from this module).
        """
        return self.autoencoder(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``self.autoencoder.encode`` with the same arguments."""
        return self.autoencoder.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.autoencoder.decode`` with the same arguments."""
        return self.autoencoder.decode(*args, **kwargs)