import math
from typing import Any, List, Optional, Tuple, Union

import torch

from wenet.transformer.search import DecodeResult
from wenet.utils.mask import (make_non_pad_mask, mask_finished_preds,
                              mask_finished_scores)


def _isChinese(ch: str):
    # Characters in the CJK range, ASCII digits and '@' all pass this check.
    if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039' or ch == '@':
        return True
    return False


def _isAllChinese(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        cur = i.replace(' ', '')
        # strip special symbols before checking the remaining characters
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        cur = cur.replace('<unk>', '')
        cur = cur.replace('<OOV>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if _isChinese(ch) is False:
            return False
    return True


def _isAllAlpha(word: Union[List[Any], str]):
    word_lists = []
    for i in word:
        cur = i.replace(' ', '')
        # strip special symbols before checking the remaining characters
        cur = cur.replace('</s>', '')
        cur = cur.replace('<s>', '')
        cur = cur.replace('<unk>', '')
        cur = cur.replace('<OOV>', '')
        word_lists.append(cur)

    if len(word_lists) == 0:
        return False

    for ch in word_lists:
        if ch.isalpha() is False and ch != "'":
            return False
        elif ch.isalpha() is True and _isChinese(ch) is True:
            return False
    return True
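
# Illustrative behaviour of the two helpers above (examples assumed, not taken
# from the original source):
#   _isAllChinese(['我', '们'])      -> True
#   _isAllAlpha(['hello', 'world'])  -> True
#   _isAllAlpha(['hel@@'])           -> False  ('@' is not alphabetic)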


def paraformer_beautify_result(tokens: List[str]) -> str:
    middle_lists = []
    word_lists = []
    word_item = ''
    # wash words lists: skip special symbols such as <s>, </s> and <unk>
    for token in tokens:
        if token in ['<s>', '</s>', '<unk>']:
            continue
        else:
            middle_lists.append(token)

    # all chinese characters
    if _isAllChinese(middle_lists):
        for _, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(' ', ''))
    # all alpha characters
    elif _isAllAlpha(middle_lists):
        for _, ch in enumerate(middle_lists):
            word = ''
            if '@@' in ch:
                # BPE continuation piece: merge into the current word
                word = ch.replace('@@', '')
                word_item += word
            else:
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
    # mix characters
    else:
        alpha_blank = False
        for _, ch in enumerate(middle_lists):
            word = ''
            if _isAllChinese(ch):
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
            elif '@@' in ch:
                word = ch.replace('@@', '')
                word_item += word
                alpha_blank = False
            elif _isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(' ')
                word_item = ''
                alpha_blank = True
            else:
                word_lists.append(ch)
                alpha_blank = False
    return ''.join(word_lists).strip()
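
# Illustrative examples (not part of the original module): '@@'-suffixed BPE
# pieces are merged into full words, alphabetic words are separated by spaces
# and Chinese characters are concatenated directly, e.g.
#   paraformer_beautify_result(['hel@@', 'lo', 'world'])
#       -> 'hello world'
#   paraformer_beautify_result(['我', '们', 'hel@@', 'lo', '世', '界'])
#       -> '我们hello世界'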


def gen_timestamps_from_peak(cif_peaks: List[int],
                             num_frames: int,
                             frame_rate=0.02):
    """Convert CIF peak positions into [start, end] times in seconds.

    Args:
        cif_peaks: frame indices at which the CIF predictor fired
        num_frames: total number of encoder frames
        frame_rate: duration of one frame in seconds (default 20 ms)
    """
    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 14
    force_time_shift = -0.5
    fire_place = [peak + force_time_shift for peak in cif_peaks]
    times = []
    for i in range(len(fire_place) - 1):
        if MAX_TOKEN_DURATION < 0 or fire_place[
                i + 1] - fire_place[i] <= MAX_TOKEN_DURATION:
            times.append(
                [fire_place[i] * frame_rate, fire_place[i + 1] * frame_rate])
        else:
            split = fire_place[i] + MAX_TOKEN_DURATION
            times.append([fire_place[i] * frame_rate, split * frame_rate])

    if len(times) > 0:
        if num_frames - fire_place[-1] > START_END_THRESHOLD:
            end = (num_frames + fire_place[-1]) * 0.5
            times[-1][1] = end * frame_rate
            times.append([end * frame_rate, num_frames * frame_rate])
        else:
            times[-1][1] = num_frames * frame_rate
    return times
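
# Illustrative example (numbers assumed, not from the original source): with
# CIF peaks at frames 3, 9 and 17 over 22 frames and the default 20 ms frame
# rate,
#   gen_timestamps_from_peak([3, 9, 17], 22)
# returns approximately [[0.05, 0.17], [0.17, 0.385], [0.385, 0.44]].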


def paraformer_greedy_search(
        decoder_out: torch.Tensor,
        decoder_out_lens: torch.Tensor,
        cif_peaks: Optional[torch.Tensor] = None) -> List[DecodeResult]:
    """Greedy (argmax) decoding of Paraformer decoder output.

    If ``cif_peaks`` is given, the frame indices where the CIF weight fires
    (peak value close to 1.0) are attached to each result as token times.
    """
    batch_size = decoder_out.shape[0]
    maxlen = decoder_out.size(1)
    topk_prob, topk_index = decoder_out.topk(1, dim=2)
    topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
    topk_prob = topk_prob.view(batch_size, maxlen)
    results: List[DecodeResult] = []
    topk_index = topk_index.cpu().tolist()
    topk_prob = topk_prob.cpu().tolist()
    decoder_out_lens = decoder_out_lens.cpu().numpy()
    for (i, hyp) in enumerate(topk_index):
        confidence = 0.0
        tokens_confidence = []
        lens = decoder_out_lens[i]
        for logp in topk_prob[i][:lens]:
            tokens_confidence.append(math.exp(logp))
            confidence += logp
        r = DecodeResult(hyp[:lens],
                         tokens_confidence=tokens_confidence,
                         confidence=math.exp(confidence / lens))
        results.append(r)

    if cif_peaks is not None:
        for (b, peaks) in enumerate(cif_peaks):
            result = results[b]
            times = []
            n_token = 0
            for (i, peak) in enumerate(peaks):
                if n_token >= len(result.tokens):
                    break
                if peak > 1 - 1e-4:
                    times.append(i)
                    n_token += 1
            result.times = times
            assert len(result.times) == len(result.tokens)
    return results


def paraformer_beam_search(decoder_out: torch.Tensor,
                           decoder_out_lens: torch.Tensor,
                           beam_size: int = 10,
                           eos: int = -1) -> List[DecodeResult]:
    """Beam search over Paraformer decoder output, keeping the best
    hypothesis of each utterance."""
    mask = make_non_pad_mask(decoder_out_lens)
    indices, _ = _batch_beam_search(decoder_out,
                                    mask,
                                    beam_size=beam_size,
                                    eos=eos)

    best_hyps = indices[:, 0, :].cpu()
    decoder_out_lens = decoder_out_lens.cpu()
    results = []
    # TODO(Mddct): scores, times etc
    for (i, hyp) in enumerate(best_hyps.tolist()):
        r = DecodeResult(hyp[:decoder_out_lens.numpy()[i]])
        results.append(r)
    return results


def _batch_beam_search(
    logit: torch.Tensor,
    masks: torch.Tensor,
    beam_size: int = 10,
    eos: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Perform batch beam search

    Args:
        logit: shape (batch_size, seq_length, vocab_size)
        masks: shape (batch_size, seq_length)
        beam_size: beam size
    Returns:
        indices: shape (batch_size, beam_size, seq_length)
        log_prob: shape (batch_size, beam_size)
    """
    batch_size, seq_length, vocab_size = logit.shape
    masks = ~masks
    # beam search
    with torch.no_grad():
        # b,t,v
        log_post = torch.nn.functional.log_softmax(logit, dim=-1)
        # b,k
        log_prob, indices = log_post[:, 0, :].topk(beam_size, sorted=True)
        end_flag = torch.eq(masks[:, 0], 1).view(-1, 1)
        # mask predictor and scores if end
        log_prob = mask_finished_scores(log_prob, end_flag)
        indices = mask_finished_preds(indices, end_flag, eos)
        # b,k,1
        indices = indices.unsqueeze(-1)

        for i in range(1, seq_length):
            # b,v
            scores = mask_finished_scores(log_post[:, i, :], end_flag)
            # b,v -> b,k,v
            topk_scores = scores.unsqueeze(1).repeat(1, beam_size, 1)
            # b,k,1 + b,k,v -> b,k,v
            top_k_logp = log_prob.unsqueeze(-1) + topk_scores
            # b,k,v -> b,k*v -> b,k
            log_prob, top_k_index = top_k_logp.view(batch_size,
                                                    -1).topk(beam_size,
                                                             sorted=True)
            index = mask_finished_preds(top_k_index, end_flag, eos)
            indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
            end_flag = torch.eq(masks[:, i], 1).view(-1, 1)

        indices = torch.fmod(indices, vocab_size)
    return indices, log_prob
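

# Minimal usage sketch (illustrative only): the random tensors below are
# stand-ins for real Paraformer decoder output, which consists of per-token
# log-probabilities of shape (B, T, V) plus the valid lengths.
if __name__ == '__main__':
    dummy_out = torch.randn(2, 6, 32).log_softmax(dim=-1)  # (B, T, V)
    dummy_lens = torch.tensor([6, 4])
    for res in paraformer_greedy_search(dummy_out, dummy_lens):
        print(res.tokens, res.confidence)
    for res in paraformer_beam_search(dummy_out, dummy_lens, beam_size=4,
                                      eos=2):
        print(res.tokens)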