# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import numpy as np
import torch
import torchaudio.functional as F


def remove_duplicates_and_blank(hyp: List[int],
                                blank_id: int = 0) -> List[int]:
    """Collapse runs of repeated ids and drop every blank.

    Args:
        hyp: frame-level id sequence
        blank_id: id treated as the blank symbol

    Returns:
        One id per run of identical consecutive ids, blank runs omitted.
    """
    # An id is kept only at the first frame of its run, and only when that
    # run is not a blank run.
    return [
        tok for i, tok in enumerate(hyp)
        if tok != blank_id and (i == 0 or tok != hyp[i - 1])
    ]


def replace_duplicates_with_blank(hyp: List[int],
                                  blank_id: int = 0) -> List[int]:
    """Keep the first id of every run and overwrite its repeats with blank.

    Args:
        hyp: frame-level id sequence
        blank_id: id treated as the blank symbol

    Returns:
        A list of the same length as ``hyp`` where every non-blank id that
        repeats the previous frame is replaced by ``blank_id``.
    """
    return [
        blank_id if
        (i > 0 and tok == hyp[i - 1] and tok != blank_id) else tok
        for i, tok in enumerate(hyp)
    ]


def gen_ctc_peak_time(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Return the frame index at which each non-blank run starts.

    Args:
        hyp: frame-level id sequence
        blank_id: id treated as the blank symbol

    Returns:
        Index of the first frame of every run of identical non-blank ids.
    """
    return [
        i for i, tok in enumerate(hyp)
        if tok != blank_id and (i == 0 or tok != hyp[i - 1])
    ]


def gen_timestamps_from_peak(
    peaks: List[int],
    max_duration: float,
    frame_rate: float = 0.04,
    max_token_duration: float = 1.0,
) -> List[Tuple[float, float]]:
    """Turn ctc peak frame indices into (start, end) spans.

    Args:
        peaks: ctc peaks time stamp
        max_duration: max_duration of the sentence
        frame_rate: frame rate of every time stamp, in seconds
        max_token_duration: max duration of the token, in seconds

    Returns:
        list(start, end) of each token
    """
    times: List[Tuple[float, float]] = []
    half_max = max_token_duration / 2
    last = len(peaks) - 1
    for i, peak in enumerate(peaks):
        center = peak * frame_rate
        # A token may start no earlier than the midpoint to the previous
        # peak (or time zero for the first token) and no earlier than
        # half_max before its own peak.
        if i == 0:
            start = max(0.0, center - half_max)
        else:
            start = max((peaks[i - 1] + peak) / 2 * frame_rate,
                        center - half_max)
        # Symmetric rule for the end: midpoint to the next peak (or
        # max_duration for the last token), capped at half_max after
        # the peak.
        if i == last:
            end = min(max_duration, center + half_max)
        else:
            end = min((peak + peaks[i + 1]) / 2 * frame_rate,
                      center + half_max)
        times.append((start, end))
    return times


def insert_blank(label, blank_id=0):
    """Insert blank token between every two label token.

    The result also starts and ends with ``blank_id``, so a label of
    length L yields a 1-D array of length 2 * L + 1. An empty label
    yields a single blank.
    """
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    # Pair every label with a leading blank, then flatten row by row.
    interleaved = np.concatenate([blanks, label], axis=1).reshape(-1)
    # Append the trailing blank from blank_id rather than from
    # interleaved[0], which does not exist when the label is empty.
    trailing = np.asarray(blank_id, dtype=interleaved.dtype)
    return np.append(interleaved, trailing)


def force_align(ctc_probs: torch.Tensor,
                y: torch.Tensor,
                blank_id=0) -> torch.Tensor:
    """ctc forced alignment.

    Args:
        torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
        torch.Tensor y: id sequence tensor 1d tensor (L)
        int blank_id: blank symbol index
    Returns:
        torch.Tensor: alignment result
    """
    # forced_align is called on CPU copies with a leading batch dimension
    # of size 1; that dimension is stripped again from the result.
    ctc_probs = ctc_probs[None].cpu()
    y = y[None].cpu()
    alignments, _ = F.forced_align(ctc_probs, y, blank=blank_id)
    return alignments[0]


def get_blank_id(configs, symbol_table):
    """Resolve the ctc blank id and record it in ``configs``.

    ``configs`` is modified in place: a ``'ctc_conf'`` entry is created
    when missing, and ``'ctc_blank_id'`` is filled in from
    ``symbol_table`` when the table defines the blank key and the config
    does not.

    Args:
        configs: configuration mapping, updated in place
        symbol_table: mapping from symbol to id

    Returns:
        Tuple of (configs, blank id).

    Raises:
        AssertionError: if config and symbol table disagree on the blank
            id, or if neither of them provides one.
    """
    if 'ctc_conf' not in configs:
        configs['ctc_conf'] = {}
    ctc_conf = configs['ctc_conf']

    # NOTE(review): the blank key is an empty string here; this looks like
    # a symbol name that was lost (e.g. a '<...>' token) - confirm against
    # the symbol table format before relying on it.
    if '' in symbol_table:
        if 'ctc_blank_id' in ctc_conf:
            assert ctc_conf['ctc_blank_id'] == symbol_table[''], (
                "ctc_blank_id in config does not match symbol table")
        else:
            ctc_conf['ctc_blank_id'] = symbol_table['']
    else:
        assert 'ctc_blank_id' in ctc_conf, "PLZ set ctc_blank_id in yaml"
    return configs, ctc_conf['ctc_blank_id']