# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2023 ASLP@NWPU (authors: He Wang, Fan Yu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet) and
# FunASR(https://github.com/alibaba-damo-academy/FunASR)
from typing import Dict, List, Optional, Tuple

import torch

from wenet.paraformer.cif import Cif, cif_without_hidden
from wenet.paraformer.layers import LFR, SanmDecoder, SanmEncoder
from wenet.paraformer.search import (paraformer_beam_search,
                                     paraformer_greedy_search)
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.search import (DecodeResult, ctc_greedy_search,
                                      ctc_prefix_beam_search)
from wenet.utils.common import IGNORE_ID, add_sos_eos, th_accuracy
from wenet.utils.mask import make_non_pad_mask


class Predictor(torch.nn.Module):
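    """CIF predictor with an extra timestamp branch.

    Wraps the base ``Cif`` predictor and adds an upsampling
    ``ConvTranspose1d`` + BLSTM branch that produces frame-level alphas
    (``tp_alphas``) used for more accurate timestamp prediction.
    """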

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0.0,
        tail_threshold=0.45,
        residual=True,
        cnn_groups=0,
        smooth_factor2=0.25,
        noise_threshold2=0.01,
        upsample_times=3,
    ):
        super().__init__()
        self.predictor = Cif(idim, l_order, r_order, threshold, dropout,
                             smooth_factor, noise_threshold, tail_threshold,
                             residual, cnn_groups)
        # accurate timestamp branch
        self.smooth_factor2 = smooth_factor2
        self.noise_threshold2 = noise_threshold2
        self.upsample_times = upsample_times
        self.tp_upsample_cnn = torch.nn.ConvTranspose1d(
            idim, idim, self.upsample_times, self.upsample_times)
        self.tp_blstm = torch.nn.LSTM(idim,
                                      idim,
                                      1,
                                      bias=True,
                                      batch_first=True,
                                      dropout=0.0,
                                      bidirectional=True)
        self.tp_output = torch.nn.Linear(idim * 2, 1)

    def forward(self,
                hidden,
                target_label: Optional[torch.Tensor] = None,
                mask: torch.Tensor = torch.tensor(0),
                ignore_id: int = -1,
                mask_chunk_predictor: Optional[torch.Tensor] = None,
                target_label_length: Optional[torch.Tensor] = None):
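        """Run the base CIF predictor and the timestamp branch.

        Returns a tuple ``(acoustic_embeds, token_num, alphas, cif_peak,
        tp_alphas, tp_token_num)``, where ``tp_alphas`` are frame-level
        weights of shape (B, T * upsample_times) and ``tp_token_num`` is
        their per-utterance sum.
        """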
        acoustic_embeds, token_num, alphas, cif_peak = self.predictor(
            hidden, target_label, mask, ignore_id, mask_chunk_predictor,
            target_label_length)
        output, (_, _) = self.tp_blstm(
            self.tp_upsample_cnn(hidden.transpose(1, 2)).transpose(1, 2))
        tp_alphas = torch.sigmoid(self.tp_output(output))
        tp_alphas = torch.nn.functional.relu(tp_alphas * self.smooth_factor2 -
                                             self.noise_threshold2)
        mask = mask.repeat(1, self.upsample_times,
                           1).transpose(-1,
                                        -2).reshape(tp_alphas.shape[0], -1)
        mask = mask.unsqueeze(-1)
        tp_alphas = tp_alphas * mask
        tp_alphas = tp_alphas.squeeze(-1)
        tp_token_num = tp_alphas.sum(-1)
        return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, tp_token_num


class Paraformer(ASRModel):
    """ Paraformer: Fast and Accurate Parallel Transformer for
        Non-autoregressive End-to-End Speech Recognition
        see https://arxiv.org/pdf/2206.08317.pdf
    """

    def __init__(self,
                 vocab_size: int,
                 encoder: BaseEncoder,
                 decoder: TransformerDecoder,
                 predictor: Predictor,
                 ctc: CTC,
                 ctc_weight: float = 0.5,
                 ignore_id: int = -1,
                 lsm_weight: float = 0.0,
                 length_normalized_loss: bool = False,
                 sampler: bool = True,
                 sampling_ratio: float = 0.75,
                 add_eos: bool = True,
                 special_tokens: Optional[Dict] = None,
                 apply_non_blank_embedding: bool = False):
        assert isinstance(encoder, SanmEncoder)
        assert isinstance(decoder, SanmDecoder)
        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
                         IGNORE_ID, 0.0, lsm_weight, length_normalized_loss,
                         None, apply_non_blank_embedding)
        if ctc_weight == 0.0:
            # no CTC branch is needed when its loss weight is zero
            del self.ctc
        self.predictor = predictor
        self.lfr = LFR()

        assert special_tokens is not None
        self.sos = special_tokens['<sos>']
        self.eos = special_tokens['<eos>']
        self.sampler = sampler
        self.sampling_ratio = sampling_ratio
        if sampler:
            self.embed = torch.nn.Embedding(vocab_size, encoder.output_size())
        # NOTE(Mddct): add eos in tail of labels for predictor
        # eg:
        #   gt:     你 好 we@@ net
        #   labels: 你 好 we@@ net eos
        self.add_eos = add_eos

    def forward(
        self,
        batch: Dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Predictor + Decoder + Calc loss
        """
        speech = batch['feats'].to(device)
        speech_lengths = batch['feats_lengths'].to(device)
        text = batch['target'].to(device)
        text_lengths = batch['target_lengths'].to(device)

        # 0 encoder
        encoder_out, encoder_out_mask = self._forward_encoder(
            speech, speech_lengths)

        # 1 predictor
        ys_pad, ys_pad_lens = text, text_lengths
        if self.add_eos:
            _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = text_lengths + 1
        acoustic_embd, token_num, _, _, _, tp_token_num = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, self.ignore_id)

        # 2 decoder with sampler
        # TODO(Mddct): support mwer here
        # only mix in ground-truth embeddings when the sampler is enabled
        if self.sampler:
            acoustic_embd = self._sampler(
                encoder_out,
                encoder_out_mask,
                ys_pad,
                ys_pad_lens,
                acoustic_embd,
            )

        # 3 loss
        # 3.1 ctc branch
        loss_ctc: Optional[torch.Tensor] = None
        if self.ctc_weight != 0.0:
            loss_ctc, _ = self._forward_ctc(encoder_out, encoder_out_mask,
                                            text, text_lengths)
        # 3.2 quantity loss for cif
        loss_quantity = torch.nn.functional.l1_loss(
            token_num,
            ys_pad_lens.to(token_num.dtype),
            reduction='sum',
        )
        loss_quantity = loss_quantity / ys_pad_lens.sum().to(token_num.dtype)
        loss_quantity_tp = torch.nn.functional.l1_loss(
            tp_token_num, ys_pad_lens.to(token_num.dtype),
            reduction='sum') / ys_pad_lens.sum().to(token_num.dtype)

        # 3.3 attention-decoder (CE) loss
        loss_decoder, acc_att = self._calc_att_loss(encoder_out,
                                                    encoder_out_mask, ys_pad,
                                                    acoustic_embd, ys_pad_lens)
        loss = loss_decoder
        if loss_ctc is not None:
            loss = loss + self.ctc_weight * loss_ctc
        loss = loss + loss_quantity + loss_quantity_tp
        return {
            "loss": loss,
            "loss_ctc": loss_ctc,
            "loss_decoder": loss_decoder,
            "loss_quantity": loss_quantity,
            "loss_quantity_tp": loss_quantity_tp,
            "th_accuracy": acc_att,
        }

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_emb: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Optional[Dict[str, List[str]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
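        """Compute the attention (decoder) loss and accuracy.

        The decoder consumes the (possibly sampled) embeddings
        ``ys_pad_emb`` in a single non-autoregressive pass.
        """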
        decoder_out, _, _ = self.decoder(encoder_out, encoder_mask, ys_pad_emb,
                                         ys_pad_lens)
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(decoder_out.view(-1, self.vocab_size),
                              ys_pad,
                              ignore_label=self.ignore_id)
        return loss_att, acc_att

    def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
                 pre_acoustic_embeds):
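        """Mix acoustic embeddings with ground-truth char embeddings.

        Runs the decoder once without gradients, counts the wrongly
        predicted tokens per utterance, and randomly replaces
        ``sampling_ratio`` of that many positions in the acoustic
        embeddings with the corresponding ground-truth char embeddings.
        """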
        device = encoder_out.device
        B, _ = ys_pad.size()

        tgt_mask = make_non_pad_mask(ys_pad_lens)
        # zero the ignore id
        ys_pad = ys_pad * tgt_mask
        ys_pad_embed = self.embed(ys_pad)  # [B, T, L]
        with torch.no_grad():
            decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask,
                                             pre_acoustic_embeds, ys_pad_lens)
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = tgt_mask
            same_num = ((pred_tokens == ys_pad) * nonpad_positions).sum(1)
            input_mask = torch.ones_like(
                nonpad_positions,
                device=device,
                dtype=tgt_mask.dtype,
            )
            for li in range(B):
                target_num = (ys_pad_lens[li] -
                              same_num[li].sum()).float() * self.sampling_ratio
                target_num = target_num.long()
                if target_num > 0:
                    input_mask[li].scatter_(
                        dim=0,
                        index=torch.randperm(ys_pad_lens[li],
                                             device=device)[:target_num],
                        value=0,
                    )
            input_mask = torch.where(input_mask > 0, 1, 0)
            input_mask = input_mask * tgt_mask
            input_mask_expand = input_mask.unsqueeze(2)  # [B, T, 1]

        semantic_embeds = torch.where(input_mask_expand == 1,
                                      pre_acoustic_embeds, ys_pad_embed)
        # zero out the paddings
        return semantic_embeds * tgt_mask.unsqueeze(2)

    def _forward_encoder(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
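        """Apply LFR (low frame rate) feature stacking, then run the encoder."""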
        # TODO(Mddct): support chunk by chunk
        assert simulate_streaming is False
        features, features_lens = self.lfr(speech, speech_lengths)
        features_lens = features_lens.to(speech_lengths.dtype)
        encoder_out, encoder_out_mask = self.encoder(features, features_lens,
                                                     decoding_chunk_size,
                                                     num_decoding_left_chunks)
        return encoder_out, encoder_out_mask

    def forward_paraformer(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
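        """Thin wrapper around ``_forward_paraformer`` returning the decoder
        log-probabilities, their lengths and the frame-level alphas.
        """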
        res = self._forward_paraformer(speech, speech_lengths)
        return res['decoder_out'], res['decoder_out_lens'], res['tp_alphas']

    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
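        """Encode one chunk of features.

        NOTE: incremental (cached) encoding is not implemented yet; the
        chunk is encoded from scratch and the caches are returned
        unchanged (see the TODO below).
        """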
        # TODO(Mddct): fix
        # lengths must be a 1-D tensor carrying the batch dimension
        xs_lens = torch.tensor([xs.size(1)], dtype=torch.int)
        encoder_out, _ = self._forward_encoder(xs, xs_lens)
        return encoder_out, att_cache, cnn_cache

    def forward_cif_peaks(self, alphas: torch.Tensor,
                          token_nums: torch.Tensor) -> torch.Tensor:
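        """Rescale ``alphas`` so they sum to ``token_nums`` and fire CIF peaks."""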
        cif2_token_nums = alphas.sum(-1)
        scale_alphas = alphas / (cif2_token_nums / token_nums).unsqueeze(1)
        peaks = cif_without_hidden(scale_alphas,
                                   self.predictor.predictor.threshold - 1e-4)
        return peaks

    def _forward_paraformer(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
    ) -> Dict[str, torch.Tensor]:
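        """Run encoder, CIF predictor and non-autoregressive decoder.

        Returns a dict with the encoder output and mask, the decoder
        log-probabilities, the frame-level alphas and the predicted token
        numbers (used as decoder output lengths).
        """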
        # encoder
        encoder_out, encoder_out_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks)

        # cif predictor
        acoustic_embed, token_num, _, _, tp_alphas, _ = self.predictor(
            encoder_out, mask=encoder_out_mask)
        token_num = token_num.floor().to(speech_lengths.dtype)

        # decoder
        decoder_out, _, _ = self.decoder(encoder_out, encoder_out_mask,
                                         acoustic_embed, token_num)
        decoder_out = decoder_out.log_softmax(dim=-1)
        return {
            "encoder_out": encoder_out,
            "encoder_out_mask": encoder_out_mask,
            "decoder_out": decoder_out,
            "tp_alphas": tp_alphas,
            "decoder_out_lens": token_num,
        }

    def decode(
        self,
        methods: List[str],
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        ctc_weight: float = 0.0,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
        context_graph=None,
        blank_id: int = 0,
        blank_penalty: float = 0.0,
        length_penalty: float = 0.0,
        infos: Optional[Dict[str, List[str]]] = None,
    ) -> Dict[str, List[DecodeResult]]:
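        """Decode with one or more of the supported methods:
        'paraformer_greedy_search', 'paraformer_beam_search',
        'ctc_greedy_search' and 'ctc_prefix_beam_search'.

        Returns a dict mapping each requested method to its list of
        ``DecodeResult``.
        """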
        res = self._forward_paraformer(speech, speech_lengths,
                                       decoding_chunk_size,
                                       num_decoding_left_chunks)
        encoder_out, encoder_mask, decoder_out, decoder_out_lens, tp_alphas = (
            res['encoder_out'], res['encoder_out_mask'], res['decoder_out'],
            res['decoder_out_lens'], res['tp_alphas'])
        peaks = self.forward_cif_peaks(tp_alphas, decoder_out_lens)

        results = {}
        if 'paraformer_greedy_search' in methods:
            assert decoder_out is not None
            assert decoder_out_lens is not None
            paraformer_greedy_result = paraformer_greedy_search(
                decoder_out, decoder_out_lens, peaks)
            results['paraformer_greedy_search'] = paraformer_greedy_result
        if 'paraformer_beam_search' in methods:
            assert decoder_out is not None
            assert decoder_out_lens is not None
            paraformer_beam_result = paraformer_beam_search(
                decoder_out,
                decoder_out_lens,
                beam_size=beam_size,
                eos=self.eos)
            results['paraformer_beam_search'] = paraformer_beam_result
        if 'ctc_greedy_search' in methods or 'ctc_prefix_beam_search' in methods:
            ctc_probs = self.ctc_logprobs(encoder_out, blank_penalty, blank_id)
            encoder_lens = encoder_mask.squeeze(1).sum(1)
            if 'ctc_greedy_search' in methods:
                results['ctc_greedy_search'] = ctc_greedy_search(
                    ctc_probs, encoder_lens, blank_id)
            if 'ctc_prefix_beam_search' in methods:
                ctc_prefix_result = ctc_prefix_beam_search(
                    ctc_probs, encoder_lens, beam_size, context_graph,
                    blank_id)
                results['ctc_prefix_beam_search'] = ctc_prefix_result
        return results