# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools as it
from typing import Any, Dict, List

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel


class BaseDecoder:
    """Common scaffolding for decoders that turn model emissions into tokens.

    The class resolves the special symbol indices from the target dictionary,
    runs the first model of an ensemble to obtain emissions, and hands them to
    :meth:`decode`, which concrete subclasses must implement.
    """

    def __init__(self, tgt_dict: Dictionary) -> None:
        """Store the dictionary and resolve the blank/silence indices.

        Args:
            tgt_dict: Target dictionary; must expose ``indices``, ``index()``,
                ``bos()`` and ``eos()``.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        # NOTE: these lookups previously used the empty string "" for both the
        # blank and the silence symbol, which made the two indistinguishable
        # and left the "|" fallback without purpose. The angle-bracket tokens
        # are the intended special symbols.

        # Blank index: the dedicated "<ctc_blank>" entry when the dictionary
        # has one, otherwise fall back to the beginning-of-sentence index.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()

        # Silence index: prefer "<sep>", then the "|" symbol, and finally the
        # end-of-sentence index.
        if "<sep>" in tgt_dict.indices:
            self.silence = tgt_dict.index("<sep>")
        elif "|" in tgt_dict.indices:
            self.silence = tgt_dict.index("|")
        else:
            self.silence = tgt_dict.eos()

    def generate(
        self, models: List[FairseqModel], sample: Dict[str, Any], **unused
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Compute emissions for ``sample`` and decode them.

        Args:
            models: Models to run; only the first one is used (see
                :meth:`get_emissions`).
            sample: Batch dictionary with a ``"net_input"`` mapping of keyword
                arguments for the model.
            **unused: Accepted and ignored.

        Returns:
            Whatever :meth:`decode` returns for the computed emissions.
        """
        # The model is invoked without "prev_output_tokens", so drop that
        # entry from the keyword arguments.
        encoder_input = {
            k: v
            for k, v in sample["net_input"].items()
            if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(
        self,
        models: List[FairseqModel],
        encoder_input: Dict[str, Any],
    ) -> torch.FloatTensor:
        """Run the first model and return its emissions on the CPU.

        Args:
            models: Models to run; only ``models[0]`` is used.
            encoder_input: Keyword arguments forwarded to the model call.

        Returns:
            The model's logits when it defines ``get_logits``, otherwise its
            normalized log-probabilities, with dimensions 0 and 1 swapped,
            cast to float32 and moved to a contiguous CPU tensor.
        """
        model = models[0]
        encoder_out = model(**encoder_input)
        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        # Presumably (time, batch, vocab) -> (batch, time, vocab); confirm
        # against the model's output layout.
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse consecutive repeats in ``idxs`` and remove blanks.

        Args:
            idxs: Iterable of token indices.

        Returns:
            A ``LongTensor`` holding one entry per run of equal consecutive
            indices, excluding runs equal to ``self.blank``.
        """
        # groupby yields one key per run of equal consecutive values.
        collapsed = (key for key, _ in it.groupby(idxs))
        return torch.LongTensor([tok for tok in collapsed if tok != self.blank])

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Turn emissions into hypotheses; must be overridden by subclasses.

        Args:
            emissions: Tensor produced by :meth:`get_emissions`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError