File size: 4,146 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: ``y[t] = x[t] - coef * x[t - 1]``.

    Implemented as a 2-tap ``F.conv1d``. The input is reflect-padded by one
    sample on the left, so the first output uses ``x[1]`` in place of the
    missing ``x[-1]``, and the output keeps the input's time length.
    """

    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # F.conv1d computes cross-correlation (the kernel is not flipped), so
        # the taps are stored pre-flipped: [-coef, 1] weights [x[t-1], x[t]].
        # Registered as a buffer so it follows .to(device/dtype) and is saved
        # in state_dict without being a trainable parameter.
        self.register_buffer(
            "flipped_filter",
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).view(1, 1, 2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis along the last dimension.

        Args:
            input: 2-D tensor, presumably ``(batch, samples)`` (TODO: confirm
                with callers). Reflect padding needs at least 2 samples.

        Returns:
            Tensor of shape ``(input.size(0), 1, input.size(1))``.
        """
        assert (
            input.dim() == 2
        ), "The number of dimensions of input tensor must be 2!"
        # Add a channel axis for conv1d, then reflect-pad one sample on the
        # left so the output length matches the input length.
        padded = F.pad(input.unsqueeze(1), (1, 0), "reflect")
        return F.conv1d(padded, self.flipped_filter)
class AFMS(nn.Module):
    """Alpha feature map scaling (AFMS), applied to each residual block's output.

    The module adds a learnable per-channel offset ``alpha`` to the feature
    map. It then multiplies the result by per-channel sigmoid gates. The gates
    are computed from the time-averaged input through a linear layer.

    References:
        [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
        [2] AFMS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # One additive offset per channel; shape (nb_dim, 1) broadcasts over time.
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``(x + alpha) * sigmoid(fc(mean_over_time(x)))``.

        ``x`` is expected as ``(batch, nb_dim, time)``. This is inferred from
        the 1-D pooling and the ``nb_dim``-wide linear layer; confirm with callers.
        """
        n_batch, n_chan = x.size(0), x.size(1)
        # Global average over time -> (batch, channels) summary vector.
        summary = F.adaptive_avg_pool1d(x, 1).view(n_batch, -1)
        # Per-channel gate in (0, 1), reshaped to broadcast over time.
        gate = self.sig(self.fc(summary)).view(n_batch, n_chan, -1)
        return (x + self.alpha) * gate
class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck block for 1-D feature maps, followed by AFMS.

    Data flow:
      1. ``conv1`` (1x1) -> ReLU -> BatchNorm maps the input to
         ``width * scale`` channels.
      2. The result is split into ``scale`` groups of ``width`` channels.
      3. Each of the first ``nums`` groups is summed with the previous group's
         output and passed through its own dilated conv -> ReLU -> BatchNorm.
         When ``scale > 1``, the last group bypasses convolution.
      4. The groups are concatenated, projected by ``conv3`` (1x1) ->
         ReLU -> BatchNorm to ``planes`` channels, and added to the shortcut.
         The shortcut is a 1x1 conv when ``inplanes != planes``, else identity.
      5. The sum is optionally max-pooled, then scaled by AFMS.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        kernel_size: kernel size of the dilated group convolutions (required).
            Must be odd: the padding preserves the time length only for odd
            kernels.
        dilation: dilation of the group convolutions (required).
        scale: number of channel groups. ``1`` means a single convolution over
            all ``width`` channels.
        pool: if truthy, kernel size of a ``MaxPool1d`` applied after the
            residual sum.
    """

    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        if kernel_size is None or dilation is None:
            # Both size the group convolutions. Without this check, None fails
            # later with an opaque arithmetic TypeError.
            raise TypeError("Bottle2neck requires explicit kernel_size and dilation")
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        # With scale == 1 there is a single group. Convolve it rather than
        # having zero convolutions. Otherwise forward() would concatenate the
        # un-convolved group onto itself and conv3 would get 2x the channels.
        self.nums = 1 if scale == 1 else scale - 1
        convs = []
        bns = []
        # "Same" padding for odd kernel sizes. An even kernel_size grows the
        # time length by `dilation`, and the shapes in forward() no longer match.
        num_pad = math.floor(kernel_size / 2) * dilation
        for _ in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        # False (not a module) when pooling is disabled; forward() checks truthiness.
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)
        if inplanes != planes:  # if change in number of filters
            # 1x1 projection so the shortcut has `planes` channels.
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        """Run the block.

        ``x`` is presumably ``(batch, inplanes, time)``; TODO confirm with
        callers. The output has ``planes`` channels.
        """
        residual = self.residual(x)

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        # Collect group outputs and concatenate once. Concatenating inside the
        # loop re-copies the growing tensor on every iteration.
        pieces = []
        sp = None
        for i in range(self.nums):
            # Each group after the first also receives the previous group's output.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            pieces.append(sp)
        if len(spx) > self.nums:
            # scale > 1: the last group is passed through without convolution.
            pieces.append(spx[self.nums])
        out = torch.cat(pieces, 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)
        return out
|