JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
2.09 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools as it
from typing import Any, Dict, List
import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.fairseq_model import FairseqModel
class BaseDecoder:
    """Common scaffolding for decoders that turn model emissions into tokens.

    The class resolves the blank and silence symbol indices from a target
    dictionary, runs the first model of an ensemble to obtain emissions, and
    hands those emissions to :meth:`decode`, which subclasses must provide.
    """

    def __init__(self, tgt_dict: Dictionary) -> None:
        """Record the dictionary and look up the special symbol indices.

        Args:
            tgt_dict: Target dictionary; must expose ``indices``, ``index()``,
                ``bos()`` and ``eos()`` and support ``len()``.
        """
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)

        # Blank symbol: the dedicated "<ctc_blank>" entry when the dictionary
        # has one, otherwise whatever the dictionary reports as BOS.
        if "<ctc_blank>" in tgt_dict.indices:
            self.blank = tgt_dict.index("<ctc_blank>")
        else:
            self.blank = tgt_dict.bos()

        # Silence symbol: first match among the candidates, in priority
        # order; the dictionary's EOS is used when none of them is present.
        for symbol in ("<sep>", "|"):
            if symbol in tgt_dict.indices:
                self.silence = tgt_dict.index(symbol)
                break
        else:
            self.silence = tgt_dict.eos()

    def generate(
        self, models: List[FairseqModel], sample: Dict[str, Any], **unused
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Compute emissions for ``sample`` and decode them.

        Args:
            models: Models to draw emissions from (see :meth:`get_emissions`).
            sample: Batch dictionary; its ``"net_input"`` entry supplies the
                keyword arguments for the model.
            **unused: Extra keyword arguments, accepted and ignored.

        Returns:
            Whatever :meth:`decode` returns for the computed emissions.
        """
        # Forward every net input except "prev_output_tokens".
        encoder_input = {}
        for name, value in sample["net_input"].items():
            if name != "prev_output_tokens":
                encoder_input[name] = value

        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(
        self,
        models: List[FairseqModel],
        encoder_input: Dict[str, Any],
    ) -> torch.FloatTensor:
        """Run the model and return its emissions as a CPU float tensor.

        Only the first entry of ``models`` is used.

        Args:
            models: Candidate models; ``models[0]`` is the one invoked.
            encoder_input: Keyword arguments passed straight to the model.

        Returns:
            The emissions with dimensions 0 and 1 swapped, cast to float,
            moved to the CPU and made contiguous.
        """
        primary = models[0]
        net_output = primary(**encoder_input)

        # Use the model's own ``get_logits`` when it defines one; otherwise
        # ask for normalized log-probabilities.
        if not hasattr(primary, "get_logits"):
            emissions = primary.get_normalized_probs(net_output, log_probs=True)
        else:
            emissions = primary.get_logits(net_output)

        # NOTE(review): presumably (time, batch, vocab) -> (batch, time,
        # vocab); confirm against the models in use.
        swapped = emissions.transpose(0, 1)
        on_host = swapped.float().cpu()
        return on_host.contiguous()

    def get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        """Collapse repeated indices and strip blanks from a label sequence.

        Args:
            idxs: Iterable of predicted indices.

        Returns:
            A ``LongTensor`` holding one entry per run of equal consecutive
            indices, excluding runs whose value equals ``self.blank``.
        """
        collapsed = [
            label for label, _run in it.groupby(idxs) if label != self.blank
        ]
        return torch.LongTensor(collapsed)

    def decode(
        self,
        emissions: torch.FloatTensor,
    ) -> List[List[Dict[str, torch.LongTensor]]]:
        """Turn emissions into hypotheses; must be overridden by subclasses.

        Args:
            emissions: Emissions as produced by :meth:`get_emissions`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError