"""Search algorithms for transducer models.""" from typing import List from typing import Union import numpy as np import torch from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state from espnet.nets.pytorch_backend.transducer.utils import init_lm_state from espnet.nets.pytorch_backend.transducer.utils import is_prefix from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps from espnet.nets.pytorch_backend.transducer.utils import select_lm_state from espnet.nets.pytorch_backend.transducer.utils import substract from espnet.nets.transducer_decoder_interface import Hypothesis from espnet.nets.transducer_decoder_interface import NSCHypothesis from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface class BeamSearchTransducer: """Beam search implementation for transducer.""" def __init__( self, decoder: Union[TransducerDecoderInterface, torch.nn.Module], joint_network: torch.nn.Module, beam_size: int, lm: torch.nn.Module = None, lm_weight: float = 0.1, search_type: str = "default", max_sym_exp: int = 2, u_max: int = 50, nstep: int = 1, prefix_alpha: int = 1, score_norm: bool = True, nbest: int = 1, ): """Initialize transducer beam search. Args: decoder: Decoder class to use joint_network: Joint Network class beam_size: Number of hypotheses kept during search lm: LM class to use lm_weight: lm weight for soft fusion search_type: type of algorithm to use for search max_sym_exp: number of maximum symbol expansions at each time step ("tsd") u_max: maximum output sequence length ("alsd") nstep: number of maximum expansion steps at each time step ("nsc") prefix_alpha: maximum prefix length in prefix search ("nsc") score_norm: normalize final scores by length ("default") nbest: number of returned final hypothesis """ self.decoder = decoder self.joint_network = joint_network self.beam_size = beam_size self.hidden_size = decoder.dunits self.vocab_size = decoder.odim self.blank = decoder.blank if self.beam_size <= 1: self.search_algorithm = self.greedy_search elif search_type == "default": self.search_algorithm = self.default_beam_search elif search_type == "tsd": self.search_algorithm = self.time_sync_decoding elif search_type == "alsd": self.search_algorithm = self.align_length_sync_decoding elif search_type == "nsc": self.search_algorithm = self.nsc_beam_search else: raise NotImplementedError self.lm = lm self.lm_weight = lm_weight if lm is not None: self.use_lm = True self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor self.lm_layers = len(self.lm_predictor.rnn) else: self.use_lm = False self.max_sym_exp = max_sym_exp self.u_max = u_max self.nstep = nstep self.prefix_alpha = prefix_alpha self.score_norm = score_norm self.nbest = nbest def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]: """Perform beam search. Args: h: Encoded speech features (T_max, D_enc) Returns: nbest_hyps: N-best decoding results """ self.decoder.set_device(h.device) if not hasattr(self.decoder, "decoders"): self.decoder.set_data_type(h.dtype) nbest_hyps = self.search_algorithm(h) return nbest_hyps def sort_nbest( self, hyps: Union[List[Hypothesis], List[NSCHypothesis]] ) -> Union[List[Hypothesis], List[NSCHypothesis]]: """Sort hypotheses by score or score given sequence length. 
    def sort_nbest(
        self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
    ) -> Union[List[Hypothesis], List[NSCHypothesis]]:
        """Sort hypotheses by score or score given sequence length.

        Args:
            hyps: List of hypotheses

        Returns:
            hyps: Sorted list of hypotheses

        """
        if self.score_norm:
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)

        return hyps[: self.nbest]

    def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
        """Greedy search implementation for transformer-transducer.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            hyp: 1-best decoding results

        """
        dec_state = self.decoder.init_state(1)

        hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
        cache = {}

        y, state, _ = self.decoder.score(hyp, cache)

        for hi in h:
            ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
            logp, pred = torch.max(ytu, dim=-1)

            # Only advance the decoder state when a non-blank symbol is emitted.
            if pred != self.blank:
                hyp.yseq.append(int(pred))
                hyp.score += float(logp)

                hyp.dec_state = state

                y, state, _ = self.decoder.score(hyp, cache)

        return [hyp]

    def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))

        dec_state = self.decoder.init_state(1)

        kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
        cache = {}

        for hi in h:
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                y, state, lm_tokens = self.decoder.score(max_hyp, cache)

                ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
                top_k = ytu[1:].topk(beam_k, dim=-1)

                # Blank extension: the hypothesis is kept for the next time step.
                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(ytu[0:1])),
                        yseq=max_hyp.yseq[:],
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
                else:
                    lm_state = max_hyp.lm_state

                # Non-blank extensions stay in the pending set for this time step.
                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[0][k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq[:] + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                # Stop expanding once at least `beam` kept hypotheses outscore
                # every pending hypothesis.
                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= beam:
                    kept_hyps = kept_most_prob
                    break

        return self.sort_nbest(kept_hyps)
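    # Note on ``default_beam_search``: the loop above follows the standard
    # monotonic RNN-T beam search (Graves, 2012). A hypothesis moves to
    # ``kept_hyps`` when extended with blank and stays in ``hyps`` when
    # extended with a label; the time step ends once no pending hypothesis
    # could still overtake the kept set. Illustrative values (not from the
    # source): with beam=2, kept scores [-1.2, -3.4] and best pending score
    # -2.0, only [-1.2] passes the filter, so expansion continues.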
    def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        for hi in h:
            A = []
            C = B

            h_enc = hi.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_logp = torch.log_softmax(
                    self.joint_network(h_enc, beam_y), dim=-1
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                seq_A = [hyp.yseq for hyp in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Blank extension of a duplicate prefix: merge scores
                        # with log-sum-exp instead of keeping both hypotheses.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_states = create_lm_batch_state(
                            [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                        )

                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(C)
                        )

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                                new_hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
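    # Note on ``time_sync_decoding``: ``max_sym_exp`` bounds how many
    # non-blank symbols may be emitted per encoder frame. For example (an
    # illustrative walk-through, not from the source), with ``max_sym_exp=2``
    # each frame allows at most two label expansions before the search moves
    # to the next frame, and duplicate prefixes reaching ``A`` are merged via
    # ``np.logaddexp`` rather than kept as separate hypotheses.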
    def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)

        h_length = int(h.size(0))
        u_max = min(self.u_max, (h_length - 1))

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        for i in range(h_length + u_max):
            A = []

            # Select hypotheses whose (t, u) position lies on the current
            # diagonal, i.e. those that still have encoder frames to consume.
            B_ = []
            h_states = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u + 1

                if t > (h_length - 1):
                    continue

                B_.append(hyp)
                h_states.append((t, h[t]))

            if B_:
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                h_enc = torch.stack([hs[1] for hs in h_states])

                beam_logp = torch.log_softmax(
                    self.joint_network(h_enc, beam_y), dim=-1
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.use_lm:
                    beam_lm_states = create_lm_batch_state(
                        [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                    )

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(B_)
                    )

                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    # Hypotheses that consumed the last encoder frame are final.
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, j, self.lm_layers, self.is_wordlm
                            )

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
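    # Note on ``align_length_sync_decoding``: the outer loop enumerates
    # anti-diagonals ``i`` of the (T, U) lattice, so a hypothesis with ``u``
    # emitted labels is scored against encoder frame ``t = i - u + 1``, and
    # ``u_max`` caps the output length. Illustrative values (not from the
    # source): with T=100 and ``u_max=50`` the loop runs 150 steps, and
    # hypotheses reaching the last frame are collected in ``final``.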
    def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
        """N-step constrained beam search implementation.

        Based on and modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside
        ESPnet until further modifications.

        Note: the algorithm is not yet in its "complete" form but works
        almost as intended.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))

        beam_state = self.decoder.init_state(beam)

        init_tokens = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        cache = {}

        beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1
            )
            lm_state = select_lm_state(
                beam_lm_states, 0, self.lm_layers, self.is_wordlm
            )
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=state,
                y=[beam_y[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        for hi in h:
            hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
            kept_hyps = []

            h_enc = hi.unsqueeze(0)

            # Prefix search: fold the score of each hypothesis that is a
            # prefix of a longer one (within prefix_alpha symbols) into the
            # longer hypothesis.
            for j, hyp_j in enumerate(hyps[:-1]):
                for hyp_i in hyps[(j + 1) :]:
                    curr_id = len(hyp_j.yseq)
                    next_id = len(hyp_i.yseq)

                    if (
                        is_prefix(hyp_j.yseq, hyp_i.yseq)
                        and (curr_id - next_id) <= self.prefix_alpha
                    ):
                        ytu = torch.log_softmax(
                            self.joint_network(hi, hyp_i.y[-1]), dim=-1
                        )

                        curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])

                        for k in range(next_id, (curr_id - 1)):
                            ytu = torch.log_softmax(
                                self.joint_network(hi, hyp_j.y[k]), dim=-1
                            )

                            curr_score += float(ytu[hyp_j.yseq[k + 1]])

                        hyp_j.score = np.logaddexp(hyp_j.score, curr_score)

            # S collects blank extensions, V label extensions; hypotheses are
            # expanded at most nstep times per encoder frame.
            S = []
            V = []
            for n in range(self.nstep):
                beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

                beam_logp = torch.log_softmax(
                    self.joint_network(h_enc, beam_y), dim=-1
                )
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                for i, hyp in enumerate(hyps):
                    S.append(
                        NSCHypothesis(
                            yseq=hyp.yseq[:],
                            score=hyp.score + float(beam_logp[i, 0:1]),
                            y=hyp.y[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )
                    )

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        score = hyp.score + float(logp)

                        if self.use_lm:
                            score += self.lm_weight * float(hyp.lm_scores[k])

                        V.append(
                            NSCHypothesis(
                                yseq=hyp.yseq[:] + [int(k)],
                                score=score,
                                y=hyp.y[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                lm_scores=hyp.lm_scores,
                            )
                        )

                V.sort(key=lambda x: x.score, reverse=True)
                V = substract(V, hyps)[:beam]

                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [v.dec_state for v in V],
                    [v.yseq for v in V],
                )
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_states = create_lm_batch_state(
                        [v.lm_state for v in V], self.lm_layers, self.is_wordlm
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(V)
                    )

                if n < (self.nstep - 1):
                    for i, v in enumerate(V):
                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]

                    hyps = V[:]
                else:
                    # Final expansion step: account for the blank probability
                    # of the last expansion when nstep > 1.
                    beam_logp = torch.log_softmax(
                        self.joint_network(h_enc, beam_y), dim=-1
                    )

                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])

                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]

            kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
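    # Note on ``nsc_beam_search``: at each time step, the prefix-search pass
    # redistributes probability mass between hypotheses that are prefixes of
    # one another before up to ``nstep`` batched expansions are performed. An
    # illustrative configuration (values are not from the source): nstep=2
    # with prefix_alpha=1 lets each frame emit up to two labels while only
    # merging hypotheses that differ by a single label.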