# Uploaded by Hecheng0625 ("Upload 409 files"), commit c968fc3.
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
from typing import Optional
import six
import torch
import numpy as np
def sequence_mask(
    lengths,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a padding mask from a tensor of sequence lengths.

    Entry ``[..., j]`` of the result is 1 (``True``) where ``j < lengths[...]``
    and 0 (``False``) otherwise. The mask broadcasts ``arange(maxlen)``
    against ``lengths.unsqueeze(-1)``, so its shape is
    ``lengths.shape + (maxlen,)``.

    Args:
        lengths: Tensor of sequence lengths (e.g. shape ``(batch,)``).
        maxlen: Width of the mask. Defaults to ``lengths.max()``; an empty
            ``lengths`` tensor yields width 0.
        dtype: dtype of the returned mask.
        device: If given, the mask is moved to this device; otherwise it stays
            on ``lengths.device``.

    Returns:
        The mask tensor, detached from the autograd graph.
    """
    if maxlen is None:
        # ``Tensor.max()`` raises on an empty tensor; an empty batch simply
        # has no columns.
        maxlen = lengths.max() if lengths.numel() > 0 else 0
    # Allocate the index row directly on ``lengths.device`` instead of
    # building it on the CPU and copying it over.
    row_vector = torch.arange(0, maxlen, 1, device=lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = (row_vector < matrix).detach()
    mask = mask.type(dtype)
    return mask.to(device) if device is not None else mask
def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    Returns True only if, for every length ``i, i - 1, ..., i - M + 1``, there
    is at least one ended hypothesis of that length and the best-scoring one
    satisfies ``best_same_length_score - best_overall_score < d_end``.

    :param ended_hyps: sequence of finished hypotheses; each is a dict with at
        least a ``"score"`` entry and a ``"yseq"`` entry whose ``len()`` is
        compared against the candidate lengths.
    :param i: reference output length; hypotheses of lengths ``i - m`` for
        ``m`` in ``range(M)`` are examined.
    :param M: number of consecutive lengths that must all satisfy the
        criterion.
    :param d_end: score-difference threshold (default ``log(exp(-10))``,
        i.e. -10).
    :return: True if decoding can be considered finished, else False.
    """
    if len(ended_hyps) == 0:
        return False
    best_score = max(hyp["score"] for hyp in ended_hyps)
    for m in range(M):
        hyp_length = i - m
        scores_same_length = [
            hyp["score"] for hyp in ended_hyps if len(hyp["yseq"]) == hyp_length
        ]
        # A length with no ended hypothesis, or whose best one is not far
        # enough below the overall best, means the criterion cannot hold for
        # all M lengths, so we can stop early.
        if not scores_same_length:
            return False
        if not (max(scores_same_length) - best_score < d_end):
            return False
    return True