import math
from typing import Optional, Tuple

import torch
from torch import nn

from wenet.transformer.attention import MultiHeadedAttention

class MultiHeadedAttentionSANM(MultiHeadedAttention):
    """Multi-head attention layer with an FSMN memory block (SANM).

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature dimension of the fused q/k/v projection.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        kernel_size (int): Kernel size of the FSMN memory block.
        sanm_shfit (int): Left shift of the FSMN memory block.
    """

    def __init__(self,
                 n_head,
                 in_feat,
                 n_feat,
                 dropout_rate,
                 kernel_size,
                 sanm_shfit=0):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # We assume d_v always equals d_k.
        # self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)
        del self.linear_q, self.linear_k, self.linear_v
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        # Depthwise 1-D convolution implementing the FSMN memory block.
        self.fsmn_block = nn.Conv1d(n_feat,
                                    n_feat,
                                    kernel_size,
                                    stride=1,
                                    padding=0,
                                    groups=n_feat,
                                    bias=False)
        # Padding so the convolution preserves the time dimension.
        self.left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            self.left_padding = self.left_padding + sanm_shfit
        self.right_padding = kernel_size - 1 - self.left_padding
        self.pad_fn = nn.ConstantPad1d((self.left_padding, self.right_padding),
                                       0.0)
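        # For example, assuming kernel_size=11 (an illustrative value, not a
        # default): with sanm_shfit=0 the padding is (left=5, right=5), and
        # with sanm_shfit=2 it becomes (left=7, right=3); either way the
        # depthwise convolution keeps the time length unchanged.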

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = query
        b, t, _ = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time1, d_k)
        k = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time2, d_k)
        v = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_fsmn(self,
                     inputs: torch.Tensor,
                     mask: torch.Tensor,
                     mask_shfit_chunk: Optional[torch.Tensor] = None):
        b, _, t, _ = inputs.size()
        # Fold the heads back: (batch, head, time, d_k) -> (batch, time, n_feat)
        inputs = inputs.transpose(1, 2).view(b, t, -1)
        if mask.size(2) > 0:  # time2 > 0
            # TODO(Mddct): make sure mask is right
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            mask = mask.transpose(1, 2)  # (B, T, 1)
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        # x = torch.nn.functional.pad(x, (self.left_padding, self.right_padding),
        #                             value=0.0,
        #                             mode='constant')
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        return x * mask
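
    # Shape trace for forward_fsmn (illustrative sizes, assuming 4 heads of
    # dimension 64): inputs of (batch, head, time, d_k) = (2, 4, 16, 64) with a
    # padding mask of (2, 1, 16) are folded to (2, 16, 256), masked, padded to
    # length 16 + kernel_size - 1 for the depthwise Conv1d, and returned as
    # (2, 16, 256) after the residual add, dropout and re-masking.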

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        mask_shfit_chunk: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q, k, v = self.forward_qkv(query, key, value)
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(Mddct): we would also need to cache the fsmn_memory state, but
        # Paraformer is non-streaming; refactor later if a streaming model
        # becomes available.
        new_cache = torch.cat((k, v), dim=-1)
        fsmn_memory = self.forward_fsmn(v,
                                        mask=mask_pad,
                                        mask_shfit_chunk=mask_shfit_chunk)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        att = self.forward_attention(v, scores, mask)
        return att + fsmn_memory, new_cache
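

# A minimal usage sketch (illustrative, not part of the original module),
# assuming a 256-dim model with 4 heads, dropout 0.1 and an FSMN kernel of 11;
# all tensor sizes below are made up for demonstration.
def _example_sanm_self_attention():
    attn = MultiHeadedAttentionSANM(n_head=4,
                                    in_feat=256,
                                    n_feat=256,
                                    dropout_rate=0.1,
                                    kernel_size=11)
    x = torch.randn(2, 16, 256)  # (batch, time, in_feat)
    mask = torch.ones(2, 1, 16, dtype=torch.bool)  # attention / padding mask
    out, cache = attn(x, x, x, mask=mask, mask_pad=mask)
    # out: (2, 16, 256); cache concatenates k and v -> (2, 4, 16, 128)
    return out, cache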


class DummyMultiHeadSANM(MultiHeadedAttentionSANM):
    """A dummy multi-head attention used in Paraformer before cross attention.

    It keeps only the FSMN memory block and drops the attention projections.
    """

    def __init__(self,
                 n_head,
                 in_feat,
                 n_feat,
                 dropout_rate,
                 kernel_size,
                 sanm_shfit=0):
        super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
                         sanm_shfit)
        del self.linear_q_k_v
        del self.linear_out

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        mask_shfit_chunk: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        query = query * mask_pad.transpose(1, 2)
        inputs = query
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        # TODO(Mddct): cache here for future streaming
        cache: Optional[torch.Tensor] = None
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        x = x * mask_pad.transpose(1, 2)
        return x, cache
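

# A hedged sketch (illustrative, not from the original module) showing how
# DummyMultiHeadSANM reduces to the FSMN memory block alone: only the padding
# mask is used, and the returned cache is always None.
def _example_dummy_sanm():
    attn = DummyMultiHeadSANM(n_head=4,
                              in_feat=256,
                              n_feat=256,
                              dropout_rate=0.1,
                              kernel_size=11)
    x = torch.randn(2, 16, 256)
    mask_pad = torch.ones(2, 1, 16, dtype=torch.bool)
    out, cache = attn(x, x, x, mask_pad=mask_pad)
    # out: (2, 16, 256), cache: None
    return out, cache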


class MultiHeadAttentionCross(MultiHeadedAttentionSANM):
    """Cross attention for the Paraformer decoder (no FSMN memory block)."""

    def __init__(self,
                 n_head,
                 in_feat,
                 n_feat,
                 dropout_rate,
                 kernel_size,
                 sanm_shfit=0,
                 target_size: Optional[int] = None):
        super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
                         sanm_shfit)
        del self.linear_q_k_v
        del self.fsmn_block
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k_v = nn.Linear(
            n_feat if target_size is None else target_size, n_feat * 2)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # NOTE(Mddct): here value == key
        _ = value
        x = query
        b = x.size(0)
        q = self.linear_q(x)
        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time1, d_k)
        k_v = self.linear_k_v(key)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(
            1, 2)  # (batch, head, time2, d_k)
        return q_h, k_h, v_h

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        mask_shfit_chunk: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        q, k, v = self.forward_qkv(query, key, key)
        q = q * self.d_k**(-0.5)
        scores = torch.matmul(q, k.transpose(-2, -1))
        # TODO(Mddct): support future streaming paraformer
        cache: Optional[torch.Tensor] = None
        return self.forward_attention(v, scores, mask), cache
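

# A hedged usage sketch (illustrative, not from the original module): cross
# attention between a decoder sequence of length 10 and an encoder output of
# length 16, both 256-dim with 4 heads. kernel_size is effectively unused
# because the FSMN block is deleted in __init__, but the constructor still
# requires it.
def _example_cross_attention():
    attn = MultiHeadAttentionCross(n_head=4,
                                   in_feat=256,
                                   n_feat=256,
                                   dropout_rate=0.1,
                                   kernel_size=11)
    query = torch.randn(2, 10, 256)  # decoder states (batch, time1, n_feat)
    memory = torch.randn(2, 16, 256)  # encoder output (batch, time2, n_feat)
    memory_mask = torch.ones(2, 1, 16, dtype=torch.bool)
    out, cache = attn(query, memory, memory, mask=memory_mask)
    # out: (2, 10, 256), cache: None
    return out, cache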