|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: ``y[t] = x[t] - coef * x[t - 1]``.

    The filter is applied with a fixed (non-trainable) two-tap convolution
    kernel stored as the ``flipped_filter`` buffer, so it follows the module
    across ``.to(...)`` calls and is saved in the state dict.

    Args:
        coef: Pre-emphasis coefficient applied to the previous sample.
    """

    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef

        # ``F.conv1d`` computes a cross-correlation, so the taps are stored in
        # the order they meet the signal: ``[-coef, 1]`` pairs ``-coef`` with
        # the previous sample and ``1`` with the current one.
        # Shape is (out_channels=1, in_channels=1, kernel_size=2).
        self.register_buffer(
            "flipped_filter",
            torch.tensor([-self.coef, 1.0], dtype=torch.float32)
            .unsqueeze(0)
            .unsqueeze(0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis along the last dimension.

        Args:
            input: 2-D tensor; the first dimension is treated as the batch and
                the second as the sample axis.

        Returns:
            Tensor of shape ``(batch, 1, samples)``: a channel dimension is
            inserted and the sample count is preserved.

        Raises:
            AssertionError: If ``input`` is not 2-D.
        """
        assert (
            input.dim() == 2
        ), "The number of dimensions of input tensor must be 2!"

        # Insert a singleton channel dimension for conv1d: (batch, 1, samples).
        input = input.unsqueeze(1)
        # Left-pad by one reflected sample so the output keeps the input
        # length; the first output therefore uses x[1] as its "previous" value.
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)
|
|
|
|
|
|
class AFMS(nn.Module):
    """
    Alpha-feature map scaling (AFMS), applied to the output of each residual block [1, 2].

    The input is shifted by a learnable per-channel offset ``alpha`` and then
    multiplied by a per-channel sigmoid gate computed from the input averaged
    over its last dimension.

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AFMS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # Additive per-channel term, shaped (nb_dim, 1) so it broadcasts over
        # the batch and last dimensions.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        # Produces the per-channel gate logits from the pooled features.
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``(x + alpha) * gate`` for a 3-D input whose dim 1 has ``nb_dim`` channels."""
        batch, channels = x.size(0), x.size(1)

        # Average over the last dimension -> (batch, channels).
        pooled = F.adaptive_avg_pool1d(x, 1).view(batch, -1)
        # Per-channel gate in (0, 1), reshaped to (batch, channels, 1) for broadcasting.
        gate = self.sig(self.fc(pooled)).view(batch, channels, -1)

        return (x + self.alpha) * gate
|
|
|
|
|
|
class Bottle2neck(nn.Module):
    """Res2Net-style 1-D bottleneck block followed by AFMS.

    Pipeline (each convolution is followed by ReLU and then BatchNorm):

    1. A 1x1 convolution maps ``inplanes`` to ``width * scale`` channels,
       where ``width = planes // scale``.
    2. The result is split along the channel axis into ``scale`` groups of
       ``width`` channels. The first ``nums`` groups are processed in
       sequence by dilated convolutions; from the second group on, the
       previous group's output is added to the group's input first. The last
       group is passed through untouched.
    3. The groups are concatenated and a 1x1 convolution maps them to
       ``planes`` channels.
    4. The (optionally projected) input is added, the result is optionally
       max-pooled, and AFMS is applied.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Kernel size of the per-group convolutions. Required; the
            ``None`` default is only a placeholder. An odd value keeps the
            sequence length unchanged, which the residual addition needs.
        dilation: Dilation of the per-group convolutions. Required.
        scale: Number of channel groups. ``scale=1`` degenerates to a single
            processed group with no pass-through group.
        pool: Falsy to disable pooling, otherwise the ``nn.MaxPool1d`` kernel
            size applied after the residual addition.

    Raises:
        TypeError: If ``kernel_size`` or ``dilation`` is left as ``None``.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()

        # Both values are needed to size the padding below; fail with an
        # explicit message instead of an opaque arithmetic TypeError.
        if kernel_size is None or dilation is None:
            raise TypeError(
                "Bottle2neck requires both kernel_size and dilation to be set"
            )

        width = planes // scale

        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)

        # Number of groups that go through a convolution. With scale == 1
        # there is only one group, and it still has to be processed.
        self.nums = max(scale - 1, 1)

        # "Same" padding for odd kernel sizes under dilation.
        num_pad = (kernel_size // 2) * dilation

        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
                for _ in range(self.nums)
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm1d(width) for _ in range(self.nums)])

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)

        self.relu = nn.ReLU()

        self.width = width

        # ``False`` (not ``None``) is kept as the "disabled" marker so existing
        # truthiness checks on ``self.mp`` keep working.
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        # Project the shortcut with a 1x1 convolution only when the channel
        # counts differ; otherwise it is the identity.
        if inplanes != planes:
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Run the block on ``x`` of shape ``(batch, inplanes, length)``.

        Returns a tensor with ``planes`` channels; the length is reduced only
        when max pooling is enabled.
        """
        residual = self.residual(x)

        out = self.bn1(self.relu(self.conv1(x)))

        # ``scale`` chunks of ``width`` channels each.
        spx = torch.split(out, self.width, 1)

        pieces = []
        sp = None
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            # Hierarchical connection: feed the previous group's output into
            # the next group's input.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = bn(self.relu(conv(sp)))
            pieces.append(sp)

        # The last chunk bypasses the convolutions. It does not exist when
        # scale == 1, where the only chunk was already processed above.
        if len(spx) > self.nums:
            pieces.append(spx[self.nums])

        out = torch.cat(pieces, 1)

        out = self.bn3(self.relu(self.conv3(out)))

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)

        return out
|
|
|