|
from transformers import PreTrainedModel, AutoConfig, AutoModel |
|
from .model import SincNet |
|
from .config import SincNetConfig |
|
|
|
class SincNetModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``SincNet``.

    The ``SincNet`` hyperparameters are read from a ``SincNetConfig``, so the
    model can be used through the transformers ``PreTrainedModel`` / Auto*
    machinery.
    """

    config_class = SincNetConfig
    base_model_prefix = "sincnet"

    def __init__(self, config: SincNetConfig) -> None:
        """Build the wrapped ``SincNet`` from *config*.

        Args:
            config: Supplies the sinc-filter, convolution, pooling and
                sample-rate settings forwarded to ``SincNet``.
        """
        super().__init__(config)

        # Keyword arguments for SincNet, read from the config in the same
        # order as before. The stride comes from ``config.stride`` rather than
        # a ``sinc_filter_stride`` field.
        # NOTE(review): confirm ``stride`` is the intended SincNetConfig field.
        sincnet_kwargs = {
            "sinc_filter_stride": config.stride,
            "num_sinc_filters": config.num_sinc_filters,
            "sinc_filter_length": config.sinc_filter_length,
            "num_conv_filters": config.num_conv_filters,
            "conv_filter_length": config.conv_filter_length,
            "pool_kernel_size": config.pool_kernel_size,
            "pool_stride": config.pool_stride,
            "sample_rate": config.sample_rate,
        }
        # Keep the attribute name ``model``: it becomes part of the parameter
        # names in saved checkpoints.
        self.model = SincNet(**sincnet_kwargs)

    def forward(self, waveforms):
        """Run *waveforms* through the wrapped ``SincNet`` and return its output unchanged.

        Args:
            waveforms: Passed straight to ``SincNet``; presumably a batch of
                raw audio tensors — TODO confirm the expected shape.
        """
        return self.model(waveforms)
|
|
|
# Register the custom classes with the transformers Auto* factories at import
# time, so AutoConfig / AutoModel can resolve the "sincnet" model type.
# NOTE(review): AutoConfig.register expects SincNetConfig.model_type to equal
# "sincnet" — confirm in .config.
AutoConfig.register("sincnet", SincNetConfig)
AutoModel.register(SincNetConfig, SincNetModel)
|
|