import math
from typing import Any, List, Optional, Tuple, Union

import torch

from wenet.transformer.search import DecodeResult
from wenet.utils.mask import (make_non_pad_mask, mask_finished_preds,
                              mask_finished_scores)


def _isChinese(ch: str):
    # Treat CJK unified ideographs, ASCII digits and the BPE marker '@' as
    # "Chinese" for the purpose of result beautification.
    if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
        return True
    return False


def _isAllChinese(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        # NOTE: the special-token strings below are assumed (the originals
        # were lost in extraction); they stand for the usual vocabulary
        # markers stripped before the Chinese/alpha check.
        cur = i.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        cur = cur.replace('<unk>', '')
        cur = cur.replace('<OOV>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if _isChinese(ch) is False:
            return False
    return True


def _isAllAlpha(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        # See the note in _isAllChinese about the assumed special tokens.
        cur = i.replace(' ', '')
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        cur = cur.replace('<unk>', '')
        cur = cur.replace('<OOV>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if ch.isalpha() is False and ch != "'":
            return False
        elif ch.isalpha() is True and _isChinese(ch) is True:
            return False

    return True


def paraformer_beautify_result(tokens: List[str]) -> str:
    middle_lists = []
    word_lists = []
    word_item = ''

    # wash words lists: drop special tokens (the exact strings below are
    # assumed; the originals were lost in extraction)
    for token in tokens:
        if token in ['<sos>', '<eos>', '<blank>']:
            continue
        else:
            middle_lists.append(token)

    # all chinese characters
    if _isAllChinese(middle_lists):
        for _, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))

    # all alpha characters
    elif _isAllAlpha(middle_lists):
        for _, ch in enumerate(middle_lists):
            word = ''
            if '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''

    # mix characters
    else:
        alpha_blank = False
        for _, ch in enumerate(middle_lists):
            word = ''
            if _isAllChinese(ch):
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
            elif _isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
            else:
                word_lists.append(ch)
                alpha_blank = False
    return ''.join(word_lists).strip()


def gen_timestamps_from_peak(cif_peaks: List[int],
                             num_frames: int,
                             frame_rate=0.02):
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 14
    force_time_shift = -0.5
    fire_place = [peak + force_time_shift for peak in cif_peaks]
    times = []
    for i in range(len(fire_place) - 1):
        if MAX_TOKEN_DURATION < 0 or fire_place[
                i + 1] - fire_place[i] <= MAX_TOKEN_DURATION:
            times.append(
                [fire_place[i] * frame_rate, fire_place[i + 1] * frame_rate])
        else:
            split = fire_place[i] + MAX_TOKEN_DURATION
            times.append([fire_place[i] * frame_rate, split * frame_rate])

    if len(times) > 0:
        if num_frames - fire_place[-1] > START_END_THRESHOLD:
            end = (num_frames + fire_place[-1]) * 0.5
            times[-1][1] = end * frame_rate
            times.append([end * frame_rate, num_frames * frame_rate])
        else:
            times[-1][1] = num_frames * frame_rate
    return times
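
# Illustrative sketch (not part of the upstream API): how the helpers above
# behave on hypothetical inputs. The token strings below are made up for the
# example; '@@' marks a BPE piece that continues into the next token.
#
#   >>> paraformer_beautify_result(['ne@@', 'ver', 'give', 'up'])
#   'never give up'
#   >>> paraformer_beautify_result(['你', '好'])
#   '你好'
#
# gen_timestamps_from_peak([10, 25, 40, 55], num_frames=70) returns four
# [start, end] pairs in seconds: three built from consecutive CIF peaks
# (each clipped to MAX_TOKEN_DURATION frames) plus one trailing segment,
# because the last peak fires more than START_END_THRESHOLD frames before
# num_frames.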
def paraformer_greedy_search(
        decoder_out: torch.Tensor,
        decoder_out_lens: torch.Tensor,
        cif_peaks: Optional[torch.Tensor] = None) -> List[DecodeResult]:
    batch_size = decoder_out.shape[0]
    maxlen = decoder_out.size(1)
    topk_prob, topk_index = decoder_out.topk(1, dim=2)
    topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
    topk_prob = topk_prob.view(batch_size, maxlen)
    results: List[DecodeResult] = []
    topk_index = topk_index.cpu().tolist()
    topk_prob = topk_prob.cpu().tolist()
    decoder_out_lens = decoder_out_lens.cpu().numpy()
    for (i, hyp) in enumerate(topk_index):
        confidence = 0.0
        tokens_confidence = []
        lens = decoder_out_lens[i]
        for logp in topk_prob[i][:lens]:
            tokens_confidence.append(math.exp(logp))
            confidence += logp
        r = DecodeResult(hyp[:lens],
                         tokens_confidence=tokens_confidence,
                         confidence=math.exp(confidence / lens))
        results.append(r)

    if cif_peaks is not None:
        for (b, peaks) in enumerate(cif_peaks):
            result = results[b]
            times = []
            n_token = 0
            for (i, peak) in enumerate(peaks):
                if n_token >= len(result.tokens):
                    break
                if peak > 1 - 1e-4:
                    times.append(i)
                    n_token += 1
            result.times = times
            assert len(result.times) == len(result.tokens)
    return results


def paraformer_beam_search(decoder_out: torch.Tensor,
                           decoder_out_lens: torch.Tensor,
                           beam_size: int = 10,
                           eos: int = -1) -> List[DecodeResult]:
    mask = make_non_pad_mask(decoder_out_lens)
    indices, _ = _batch_beam_search(decoder_out,
                                    mask,
                                    beam_size=beam_size,
                                    eos=eos)

    best_hyps = indices[:, 0, :].cpu()
    decoder_out_lens = decoder_out_lens.cpu()
    results = []
    # TODO(Mddct): scores, times etc
    for (i, hyp) in enumerate(best_hyps.tolist()):
        r = DecodeResult(hyp[:decoder_out_lens.numpy()[i]])
        results.append(r)
    return results


def _batch_beam_search(
    logit: torch.Tensor,
    masks: torch.Tensor,
    beam_size: int = 10,
    eos: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Perform batch beam search

        Args:
            logit: shape (batch_size, seq_length, vocab_size)
            masks: shape (batch_size, seq_length)
            beam_size: beam size

        Returns:
            indices: shape (batch_size, beam_size, seq_length)
            log_prob: shape (batch_size, beam_size)
    """
    batch_size, seq_length, vocab_size = logit.shape
    masks = ~masks
    # beam search
    with torch.no_grad():
        # b,t,v
        log_post = torch.nn.functional.log_softmax(logit, dim=-1)
        # b,k
        log_prob, indices = log_post[:, 0, :].topk(beam_size, sorted=True)
        end_flag = torch.eq(masks[:, 0], 1).view(-1, 1)
        # mask predictor and scores if end
        log_prob = mask_finished_scores(log_prob, end_flag)
        indices = mask_finished_preds(indices, end_flag, eos)
        # b,k,1
        indices = indices.unsqueeze(-1)

        for i in range(1, seq_length):
            # b,v
            scores = mask_finished_scores(log_post[:, i, :], end_flag)
            # b,v -> b,k,v
            topk_scores = scores.unsqueeze(1).repeat(1, beam_size, 1)
            # b,k,1 + b,k,v -> b,k,v
            top_k_logp = log_prob.unsqueeze(-1) + topk_scores
            # b,k,v -> b,k*v -> b,k
            log_prob, top_k_index = top_k_logp.view(batch_size,
                                                    -1).topk(beam_size,
                                                             sorted=True)
            index = mask_finished_preds(top_k_index, end_flag, eos)
            indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
            end_flag = torch.eq(masks[:, i], 1).view(-1, 1)

        indices = torch.fmod(indices, vocab_size)

    return indices, log_prob
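
# A minimal smoke-test sketch, assuming this module is run as a script from a
# WeNet checkout (so the wenet imports above resolve). Shapes, vocabulary size
# and the eos id are arbitrary; the "decoder output" is just random
# log-probabilities standing in for a real Paraformer decoder.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, maxlen, vocab = 2, 6, 32
    decoder_out = torch.randn(batch, maxlen, vocab).log_softmax(dim=-1)
    decoder_out_lens = torch.tensor([6, 4])

    greedy = paraformer_greedy_search(decoder_out, decoder_out_lens)
    beam = paraformer_beam_search(decoder_out,
                                  decoder_out_lens,
                                  beam_size=3,
                                  eos=2)
    for g, b in zip(greedy, beam):
        print('greedy:', g.tokens, 'confidence: %.3f' % g.confidence)
        print('beam  :', b.tokens)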