from typing import Optional

import numpy as np
import torch


def sequence_mask(
    lengths: torch.Tensor,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Make a mask that is 1 inside each sequence and 0 in the padding.

    :param torch.Tensor lengths: batch of sequence lengths
    :param int maxlen: width of the mask; defaults to lengths.max()
    :param torch.dtype dtype: output dtype
    :param torch.device device: output device; defaults to lengths.device
    :return: mask of shape lengths.shape + (maxlen,)
    :rtype: torch.Tensor
    """
    if maxlen is None:
        maxlen = int(lengths.max())
    # Compare a [0, maxlen) index row against each length broadcast as a column.
    row_vector = torch.arange(0, maxlen, device=lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = (row_vector < matrix).detach()
    # device=None leaves the mask on lengths.device.
    return mask.to(dtype=dtype, device=device)
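
# Usage sketch (illustrative example, not part of the original module): mask a
# padded batch of three sequences with lengths 2, 4, and 1.
#
#     lengths = torch.tensor([2, 4, 1])
#     sequence_mask(lengths, maxlen=4, dtype=torch.bool)
#     # tensor([[ True,  True, False, False],
#     #         [ True,  True,  True,  True],
#     #         [ True, False, False, False]])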


def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param list ended_hyps: already-ended hypotheses, each a dict with a token
        sequence "yseq" and a log-probability "score"
    :param int i: current output position (hypothesis length) in the search
    :param int M: number of most recent hypothesis lengths to check
    :param float d_end: score-gap threshold in the log domain
        (the default evaluates to -10)
    :return: True if decoding can be stopped
    :rtype: bool
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in range(M):
        # For each of the M most recent lengths, compare the best hypothesis
        # of that length against the overall best ended hypothesis.
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < d_end:
                count += 1

    # Stop when all M recent lengths fall below the threshold.
    return count == M
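

if __name__ == "__main__":
    # Quick illustrative demo (not part of the original module): exercise both
    # helpers. end_detect should fire once the best hypotheses at the M most
    # recent lengths all score far below the best ended hypothesis (gap
    # beyond d_end, i.e. about -10 with the default).
    lengths = torch.tensor([2, 4, 1])
    print(sequence_mask(lengths, maxlen=4, dtype=torch.bool))

    hyps = [
        {"yseq": [1, 2, 3], "score": -1.0},  # strong short hypothesis
        {"yseq": [1, 2, 3, 4], "score": -20.0},
        {"yseq": [1, 2, 3, 4, 5], "score": -21.0},
        {"yseq": [1, 2, 3, 4, 5, 6], "score": -22.0},
    ]
    print(end_detect(hyps, i=6))  # True: lengths 6, 5, 4 all trail by > 10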