|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Pooling functions to aggregate frame-level deep features |
|
into segment-level speaker embeddings |
|
|
|
High-order statistics are surprisingly effective: on VoxCeleb, TSDP performs
similarly to TSTP even though it drops the mean statistic.
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class TAP(nn.Module):
    """
    Temporal average pooling: only the first-order statistic (mean) is used.

    The input is averaged over its last (time) axis, and every remaining axis
    after the batch axis is flattened, so (B, F, T) -> (B, F) and
    (B, C, F, T) -> (B, C*F).
    """

    def __init__(self, in_dim=0, **kwargs):
        super().__init__()
        # Extra keyword arguments are accepted and ignored.
        self.in_dim = in_dim

    def forward(self, x):
        return torch.flatten(torch.mean(x, dim=-1), start_dim=1)

    def get_out_dim(self):
        # The mean keeps exactly one value per input feature.
        self.out_dim = self.in_dim
        return self.out_dim
|
|
|
|
|
class TSDP(nn.Module):
    """
    Temporal standard deviation pooling: only the second-order statistic
    (standard deviation) is used.

    The std over the last (time) axis is flattened after the batch axis, so
    (B, F, T) -> (B, F) and (B, C, F, T) -> (B, C*F).
    """

    def __init__(self, in_dim=0, **kwargs):
        super(TSDP, self).__init__()
        self.in_dim = in_dim

    def forward(self, x):
        num_frames = x.shape[-1]
        # The default (Bessel-corrected) variance is undefined for a single
        # frame and would make the whole embedding NaN; use the biased
        # estimate (exactly 0) in that case. For T > 1 this matches the
        # default torch.var behaviour.
        var = torch.var(x, dim=-1, unbiased=num_frames > 1)
        # The epsilon keeps sqrt and its gradient finite at zero variance.
        pooling_std = torch.sqrt(var + 1e-7)
        return pooling_std.flatten(start_dim=1)

    def get_out_dim(self):
        # One std value per input feature.
        self.out_dim = self.in_dim
        return self.out_dim
|
|
|
|
|
class TSTP(nn.Module):
    """
    Temporal statistics pooling: concatenates the mean and the standard
    deviation over time, as used in x-vector.

    Comment: simple concatenation cannot make full use of both statistics.

    Each statistic is flattened after the batch axis before concatenation,
    so (B, F, T) -> (B, 2F) and (B, C, F, T) -> (B, 2*C*F).
    """

    def __init__(self, in_dim=0, **kwargs):
        super(TSTP, self).__init__()
        self.in_dim = in_dim

    def forward(self, x):
        num_frames = x.shape[-1]
        pooling_mean = x.mean(dim=-1)
        # The default (Bessel-corrected) variance is NaN for a single frame;
        # fall back to the biased estimate (exactly 0) there. For T > 1 this
        # matches the default torch.var behaviour.
        var = torch.var(x, dim=-1, unbiased=num_frames > 1)
        # The epsilon keeps sqrt and its gradient finite at zero variance.
        pooling_std = torch.sqrt(var + 1e-7)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)
        return torch.cat((pooling_mean, pooling_std), 1)

    def get_out_dim(self):
        # Mean and std each contribute in_dim values.
        self.out_dim = self.in_dim * 2
        return self.out_dim
|
|
|
|
|
class ASTP(nn.Module):
    """Attentive statistics pooling: channel- and context-dependent
    statistics pooling, first used in ECAPA_TDNN.

    Args:
        in_dim: number of input channels after a 4-D input has been
            flattened to (B, C*F, T).
        bottleneck_dim: width of the hidden layer between the two 1x1
            convolutions that produce the attention logits.
        global_context_att: if True, the attention network also sees the
            per-channel mean and std over time, broadcast to every frame.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs):
        super(ASTP, self).__init__()
        self.in_dim = in_dim
        self.global_context_att = global_context_att

        # With global context, the attention input is [x; mean; std] stacked
        # along the channel axis, i.e. three times as wide as x.
        att_in_dim = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(att_in_dim, bottleneck_dim, kernel_size=1)
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)

    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
            or a 4-dimensional tensor in resnet architecture (B,C,F,T)
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)

        Returns a (B, 2 * in_dim) tensor: the attention-weighted mean and
        std, concatenated along dim 1.
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3

        if self.global_context_att:
            num_frames = x.shape[-1]
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            # Unbiased variance is NaN for a single frame, which would poison
            # the attention weights; use the biased estimate (0) there.
            context_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True, unbiased=num_frames > 1) + 1e-7
            ).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # Attention weights per channel and frame, normalized over time.
        alpha = torch.tanh(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding; clamp
        # before taking the square root.
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-7))
        return torch.cat([mean, std], dim=1)

    def get_out_dim(self):
        # Weighted mean and weighted std each contribute in_dim values.
        self.out_dim = 2 * self.in_dim
        return self.out_dim
|
|
|
|
|
class MHASTP(torch.nn.Module):
    """Multi head attentive statistics pooling

    Reference:
        Self Multi-Head Attention for Speaker Recognition
        https://arxiv.org/pdf/1906.09890.pdf

    The input channels are split into ``head_num`` equal chunks. Each chunk
    has its own stack of 1x1 convolutions that produces attention logits
    over time, and contributes an attention-weighted mean and std.

    Args:
        in_dim: number of input channels; must be divisible by head_num.
        layer_num: number of 1x1 conv layers in each head's attention stack.
        head_num: number of heads.
        d_s: 1 -> one attention weight per frame, shared by all channels of a
            head; any value > 1 -> one weight per channel (stored d_s is then
            reset to the per-head width ``in_dim // head_num``).
        bottleneck_dim: width of the hidden conv layers.
    """

    def __init__(
        self, in_dim, layer_num=2, head_num=2, d_s=1, bottleneck_dim=64, **kwargs
    ):
        super(MHASTP, self).__init__()
        assert in_dim % head_num == 0, "in_dim must be divisible by head_num"
        self.in_dim = in_dim
        self.head_num = head_num
        d_model = in_dim // head_num
        self.d_s = d_model if d_s > 1 else 1

        # Layer widths: d_model -> bottleneck_dim -> ... -> d_s
        channel_dims = [bottleneck_dim] * (layer_num + 1)
        channel_dims[0], channel_dims[-1] = d_model, self.d_s

        heads_att_trans = []
        for _ in range(self.head_num):
            att_trans = nn.Sequential()
            for layer_idx in range(layer_num - 1):
                att_trans.add_module(
                    "att_" + str(layer_idx),
                    nn.Conv1d(
                        channel_dims[layer_idx], channel_dims[layer_idx + 1], 1, 1
                    ),
                )
                att_trans.add_module("tanh" + str(layer_idx), nn.Tanh())
            # Final projection to the attention logits (no activation).
            att_trans.add_module(
                "att_" + str(layer_num - 1),
                nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], 1, 1),
            )
            heads_att_trans.append(att_trans)
        self.heads_att_trans = nn.ModuleList(heads_att_trans)

    def forward(self, input):
        """
        input: a 3-dimensional tensor in xvector architecture
            or a 4-dimensional tensor in resnet architecture
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)

        Returns a (B, 2 * in_dim) tensor: for each head, its weighted mean
        followed by its weighted std, heads concatenated in order.
        """
        if len(input.shape) == 4:
            input = input.reshape(
                input.shape[0], input.shape[1] * input.shape[2], input.shape[3]
            )
        assert len(input.shape) == 3
        chunks = torch.chunk(input, self.head_num, 1)

        chunks_out = []
        for i, layer in enumerate(self.heads_att_trans):
            chunk = chunks[i]
            # Normalize attention logits over the time axis.
            alpha = F.softmax(layer(chunk), dim=-1)
            mean = torch.sum(alpha * chunk, dim=2)
            # E[x^2] - E[x]^2 can dip below zero from rounding; clamp first.
            var = torch.sum(alpha * chunk**2, dim=2) - mean**2
            std = torch.sqrt(var.clamp(min=1e-7))
            chunks_out.append(torch.cat((mean, std), dim=1))
        return torch.cat(chunks_out, dim=1)

    def get_out_dim(self):
        # Every input channel contributes a weighted mean and a weighted std.
        self.out_dim = 2 * self.in_dim
        return self.out_dim
|
|
|
|
|
class MQMHASTP(torch.nn.Module):
    """Multi-query multi-head attentive statistics pooling.

    Reference:
        multi query multi head attentive statistics pooling
        https://arxiv.org/pdf/2110.05042.pdf

    Builds ``query_num`` independent MHASTP modules, applies each one to the
    same input, and concatenates their outputs along the last axis.

    Args:
        in_dim: the feature dimension of input
        layer_num: the number of layers in each attention stack
        query_num: the number of queries
        head_num: the number of heads
        d_s: forwarded to MHASTP (1 -> shared weight per frame,
            > 1 -> per-channel weights)
        bottleneck_dim: the bottleneck dimension

    Special cases (H heads, Q queries, n layers, d_s):
        SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
            https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
        MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
            https://arxiv.org/pdf/1906.09890.pdf
        AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
            https://arxiv.org/pdf/1803.10963.pdf
        VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
            http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
    """

    def __init__(
        self,
        in_dim,
        layer_num=2,
        query_num=2,
        head_num=8,
        d_s=2,
        bottleneck_dim=64,
        **kwargs
    ):
        super(MQMHASTP, self).__init__()
        queries = []
        for _ in range(query_num):
            queries.append(
                MHASTP(
                    in_dim,
                    layer_num=layer_num,
                    head_num=head_num,
                    d_s=d_s,
                    bottleneck_dim=bottleneck_dim,
                )
            )
        self.n_query = nn.ModuleList(queries)
        self.query_num = query_num
        self.in_dim = in_dim

    def forward(self, input):
        """
        input: a 3-dimensional tensor in xvector architecture
            or a 4-dimensional tensor in resnet architecture
            0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if input.ndim == 4:
            batch, channels, freq, frames = input.shape
            input = input.reshape(batch, channels * freq, frames)
        assert input.ndim == 3
        return torch.cat([query(input) for query in self.n_query], dim=-1)

    def get_out_dim(self):
        # Each query yields 2 * in_dim values (weighted mean + std).
        self.out_dim = 2 * self.query_num * self.in_dim
        return self.out_dim
|
|