import math
from typing import Optional, Tuple

import torch
from torch import nn

from wenet.transformer.attention import MultiHeadedAttention


class MultiHeadedAttentionSANM(MultiHeadedAttention):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0):
"""Construct an MultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# We assume d_v always equals d_k
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
del self.linear_q, self.linear_k, self.linear_v
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.fsmn_block = nn.Conv1d(n_feat,
n_feat,
kernel_size,
stride=1,
padding=0,
groups=n_feat,
bias=False)
        # Asymmetric padding for the FSMN depthwise conv: sanm_shfit > 0 moves
        # extra context to the left (past) side of the kernel.
        self.left_padding = (kernel_size - 1) // 2
if sanm_shfit > 0:
self.left_padding = self.left_padding + sanm_shfit
self.right_padding = kernel_size - 1 - self.left_padding
self.pad_fn = nn.ConstantPad1d((self.left_padding, self.right_padding),
0.0)

    def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = query
b, t, _ = x.size()
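        # A single fused projection computes q, k and v in one matmul; the
        # output is split into three n_feat-sized chunks right below.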
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time1, d_k)
k = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
v = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
return q, k, v

    def forward_fsmn(self,
inputs: torch.Tensor,
mask: torch.Tensor,
mask_shfit_chunk: Optional[torch.Tensor] = None):
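        # inputs arrive as (batch, head, time, d_k); fold the heads back into
        # a single feature dimension, i.e. (batch, time, n_feat).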
b, _, t, _ = inputs.size()
inputs = inputs.transpose(1, 2).view(b, t, -1)
if mask.size(2) > 0: # time2 > 0
# TODO(Mddct): make sure mask is right
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
mask = mask.transpose(1, 2) # [B,T,1]
inputs = inputs * mask
x = inputs.transpose(1, 2)
# x = torch.nn.functional.pad(x, (self.left_padding, self.right_padding),
# value=0.0,
# mode='constant')
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
return x * mask

    def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = self.forward_qkv(query, key, value)
if cache.size(0) > 0:
key_cache, value_cache = torch.split(cache,
cache.size(-1) // 2,
dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
        # NOTE(Mddct): the fsmn memory would also need its own cache, but
        # Paraformer is non-streaming; refactor later if a streaming model
        # becomes available.
new_cache = torch.cat((k, v), dim=-1)
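        # FSMN memory: a depthwise convolution over the value projection,
        # added to the attention output below.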
fsmn_memory = self.forward_fsmn(v,
mask=mask_pad,
mask_shfit_chunk=mask_shfit_chunk)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
att = self.forward_attention(v, scores, mask)
return att + fsmn_memory, new_cache


class DummyMultiHeadSANM(MultiHeadedAttentionSANM):
"""A dummy multihead attention for Paraformer befroe cross attention
"""
def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0):
super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
sanm_shfit)
del self.linear_q_k_v
del self.linear_out

    def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
query = query * mask_pad.transpose(1, 2)
inputs = query
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
# TODO(Mddct): cache here for future streaming
cache: Optional[torch.Tensor] = None
x = self.fsmn_block(x)
x = x.transpose(1, 2)
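        # If padding/convolution changed the time length, fall back to adding
        # only the last input frame as the residual.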
if x.size(1) != inputs.size(1):
inputs = inputs[:, -1, :]
x = x + inputs
x = self.dropout(x)
x = x * mask_pad.transpose(1, 2)
return x, cache


class MultiHeadAttentionCross(MultiHeadedAttentionSANM):
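    """Cross attention between the decoder hidden states (query) and the
    encoder memory, where key == value, as used in the Paraformer decoder.
    """
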
def __init__(self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0,
target_size: Optional[int] = None):
super().__init__(n_head, in_feat, n_feat, dropout_rate, kernel_size,
sanm_shfit)
del self.linear_q_k_v
del self.fsmn_block
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(
n_feat if target_size is None else target_size, n_feat * 2)

    def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NOTE(Mddct): here value == key
_ = value
x = query
b = x.size(0)
q = self.linear_q(x)
q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time1, d_k)
k_v = self.linear_k_v(key)
k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(
1, 2) # (batch, head, time2, d_k)
return q_h, k_h, v_h

    def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
mask_shfit_chunk: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
q, k, v = self.forward_qkv(query, key, key)
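        # Scale q by 1/sqrt(d_k) up front; equivalent to dividing the scores.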
q = q * self.d_k**(-0.5)
scores = torch.matmul(q, k.transpose(-2, -1))
# TODO(Mddct): support future streaming paraformer
cache: Optional[torch.Tensor] = None
return self.forward_attention(v, scores, mask), cache
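

# Minimal usage sketch (not part of the original module). The shapes and
# hyper-parameters below are illustrative assumptions, not values taken from
# any Paraformer recipe.
if __name__ == "__main__":
    batch, time1, n_feat, n_head = 2, 16, 256, 4
    x = torch.randn(batch, time1, n_feat)
    mask = torch.ones(batch, 1, time1, dtype=torch.bool)

    # Self attention with FSMN memory.
    self_attn = MultiHeadedAttentionSANM(n_head, n_feat, n_feat,
                                         dropout_rate=0.0, kernel_size=11)
    out, _cache = self_attn(x, x, x, mask=mask, mask_pad=mask)
    print(out.shape)  # torch.Size([2, 16, 256])

    # FSMN-only block applied before cross attention.
    dummy = DummyMultiHeadSANM(n_head, n_feat, n_feat,
                               dropout_rate=0.0, kernel_size=11)
    out, _ = dummy(x, x, x, mask=mask, mask_pad=mask)
    print(out.shape)  # torch.Size([2, 16, 256])

    # Cross attention against an encoder memory of a different length.
    memory = torch.randn(batch, 32, n_feat)
    memory_mask = torch.ones(batch, 1, 32, dtype=torch.bool)
    cross_attn = MultiHeadAttentionCross(n_head, n_feat, n_feat,
                                         dropout_rate=0.0, kernel_size=11)
    out, _ = cross_attn(x, memory, memory, mask=memory_mask)
    print(out.shape)  # torch.Size([2, 16, 256])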