# DINO-HuVITS / src/wespeaker_campplus/fbank_feature_extractor.py
# Commit 36a67ca: "feat: added model" (SazerLife)
import torch.nn as nn
import torch
import torch.nn.functional as F
import torchaudio
class PreEmphasis(torch.nn.Module):
    """Apply a first-order pre-emphasis filter to a batch of waveforms.

    Each output sample is ``x[t] - coef * x[t - 1]``. The sample preceding
    the first one is obtained by reflect-padding the signal by one element
    on the left, so the output has the same length as the input.

    Args:
        coef: Pre-emphasis coefficient applied to the previous sample.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # Kernel of shape (out_channels=1, in_channels=1, width=2). ``conv1d``
        # computes a cross-correlation, so the taps are stored as [-coef, 1]
        # to give y[t] = x[t] - coef * x[t - 1]. Registered as a buffer so it
        # is part of the module state without being a trainable parameter.
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).view(1, 1, 2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter ``input`` and return a tensor of the same shape.

        Args:
            input: Waveforms shaped ``(batch, time)``. ``time`` must be at
                least 2 because of the one-sample reflect padding.

        Returns:
            The pre-emphasised waveforms, shaped ``(batch, time)``.
        """
        # (batch, time) -> (batch, 1, time): conv1d expects a channel axis.
        signal = input.unsqueeze(1)
        # One reflected sample on the left keeps the output length unchanged.
        signal = F.pad(signal, (1, 0), 'reflect')
        return F.conv1d(signal, self.flipped_filter).squeeze(1)
class FbankFeatureExtractor(nn.Module):
    """Log-mel filterbank front end with mean normalisation.

    The pipeline is: pre-emphasis -> mel spectrogram (16 kHz sample rate,
    512-point FFT, 400-sample Hamming window, 160-sample hop, mel range
    starting at 20 Hz) -> ``log(mel + 1e-6)`` -> subtraction of the mean
    taken over the last axis.

    Args:
        feat_dim: Number of mel bins produced by the spectrogram.
        f_max: Upper frequency bound (Hz) of the mel filterbank.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, feat_dim=80, f_max=7600, **kwargs):
        super().__init__()
        pre_emphasis = PreEmphasis()
        mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            win_length=400,
            hop_length=160,
            f_min=20,
            f_max=f_max,
            window_fn=torch.hamming_window,
            n_mels=feat_dim,
        )
        self.torchfbank = nn.Sequential(pre_emphasis, mel_spectrogram)
        # Not used by ``forward`` below; kept so the attribute stays available.
        self.instance_norm = nn.InstanceNorm1d(feat_dim)

    def forward(self, x):
        """Return mean-normalised log-mel features for ``x``.

        Feature extraction runs under ``torch.no_grad()``, so the result
        never carries gradient history back to the input.
        """
        with torch.no_grad():
            # The 1e-6 offset keeps the logarithm finite for all-zero bins.
            log_mel = torch.log(self.torchfbank(x) + 1e-6)
            # Centre every row along the last axis.
            centred = log_mel - log_mel.mean(dim=-1, keepdim=True)
        return centred