|
from transformers import PreTrainedModel |
|
from audio_encoders_pytorch import AutoEncoder1d as AE1d, TanhBottleneck |
|
from .autoencoder_config import AutoEncoder1dConfig |
|
|
|
# Registry mapping config-string names to bottleneck classes (instantiated in __init__).
bottleneck = dict(tanh=TanhBottleneck)
|
|
|
class AutoEncoder1d(PreTrainedModel):
    """Hugging Face wrapper around ``audio_encoders_pytorch.AutoEncoder1d``.

    Exposes the underlying 1-D audio autoencoder through the
    ``PreTrainedModel`` interface so it can be serialized and restored with
    ``save_pretrained`` / ``from_pretrained`` using an ``AutoEncoder1dConfig``.
    """

    config_class = AutoEncoder1dConfig

    def __init__(self, config: AutoEncoder1dConfig):
        """Build the wrapped autoencoder from the hyperparameters in *config*.

        Raises:
            ValueError: if ``config.bottleneck`` does not name a registered
                bottleneck type.
        """
        super().__init__(config)

        # Fail early with a clear message instead of an opaque KeyError
        # when the config names an unsupported bottleneck.
        if config.bottleneck not in bottleneck:
            raise ValueError(
                f"Unknown bottleneck '{config.bottleneck}'; "
                f"expected one of {sorted(bottleneck)}"
            )

        self.autoencoder = AE1d(
            in_channels = config.in_channels,
            patch_size = config.patch_size,
            channels = config.channels,
            multipliers = config.multipliers,
            factors = config.factors,
            num_blocks = config.num_blocks,
            bottleneck = bottleneck[config.bottleneck](),
        )

    def forward(self, *args, **kwargs):
        """Run the full encode/decode pass of the wrapped autoencoder."""
        return self.autoencoder(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Encode input into the latent representation."""
        return self.autoencoder.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode a latent representation back to the signal domain."""
        return self.autoencoder.decode(*args, **kwargs)
|
|