#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
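"""Kaldi decoder for fairseq acoustic models.

Wraps PyKaldi's FasterDecoder/LatticeFasterDecoder to decode acoustic model
emissions against a precompiled HLG graph. Lattice decoding supports n-best
output; decode jobs are dispatched to a thread pool and returned as futures.
"""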
import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

import torch
from omegaconf import MISSING

from fairseq.dataclass import FairseqDataclass

from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi


logger = logging.getLogger(__name__)
@dataclass
class KaldiDecoderConfig(FairseqDataclass):
    hlg_graph_path: Optional[str] = None
    output_dict: str = MISSING

    kaldi_initializer_config: Optional[KaldiInitializerConfig] = None

    # decoder options applied in both decoding modes
    acoustic_scale: float = 0.5
    max_active: int = 10000
    beam_delta: float = 0.5
    hash_ratio: float = 2.0

    # lattice decoder options (used when is_lattice=True)
    is_lattice: bool = False
    lattice_beam: float = 10.0
    prune_interval: int = 25
    determinize_lattice: bool = True
    prune_scale: float = 0.1

    # lattice determinization options
    max_mem: int = 0
    phone_determinize: bool = True
    word_determinize: bool = True
    minimize: bool = True

    num_threads: int = 1
class KaldiDecoder(object):
    def __init__(
        self,
        cfg: KaldiDecoderConfig,
        beam: int,
        nbest: int = 1,
    ):
        try:
            from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
            from kaldi.base import set_verbose_level
            from kaldi.decoder import (
                FasterDecoder,
                FasterDecoderOptions,
                LatticeFasterDecoder,
                LatticeFasterDecoderOptions,
            )
            from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
            from kaldi.fstext import read_fst_kaldi, SymbolTable
        except ImportError:
            warnings.warn(
                "pykaldi is required for this functionality. "
                "Please install from https://github.com/pykaldi/pykaldi"
            )

        # set_verbose_level(2)

        self.acoustic_scale = cfg.acoustic_scale
        self.nbest = nbest

        if cfg.hlg_graph_path is None:
            assert (
                cfg.kaldi_initializer_config is not None
            ), "Must provide hlg graph path or kaldi initializer config"
            cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)

        assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path

        if cfg.is_lattice:
            self.dec_cls = LatticeFasterDecoder
            opt_cls = LatticeFasterDecoderOptions
            self.rec_cls = LatticeFasterRecognizer
        else:
            assert self.nbest == 1, "nbest > 1 requires lattice decoder"
            self.dec_cls = FasterDecoder
            opt_cls = FasterDecoderOptions
            self.rec_cls = FasterRecognizer

        self.decoder_options = opt_cls()
        self.decoder_options.beam = beam
        self.decoder_options.max_active = cfg.max_active
        self.decoder_options.beam_delta = cfg.beam_delta
        self.decoder_options.hash_ratio = cfg.hash_ratio

        if cfg.is_lattice:
            self.decoder_options.lattice_beam = cfg.lattice_beam
            self.decoder_options.prune_interval = cfg.prune_interval
            self.decoder_options.determinize_lattice = cfg.determinize_lattice
            self.decoder_options.prune_scale = cfg.prune_scale

            det_opts = DeterminizeLatticePhonePrunedOptions()
            det_opts.max_mem = cfg.max_mem
            det_opts.phone_determinize = cfg.phone_determinize
            det_opts.word_determinize = cfg.word_determinize
            det_opts.minimize = cfg.minimize
            self.decoder_options.det_opts = det_opts

        # output_dict has one "<symbol> <id>" pair per line; build the
        # reverse id -> symbol lookup used when reading n-best paths
        self.output_symbols = {}
        with open(cfg.output_dict, "r") as f:
            for line in f:
                items = line.rstrip().split()
                assert len(items) == 2
                self.output_symbols[int(items[1])] = items[0]

        logger.info(f"Loading FST from {cfg.hlg_graph_path}")

        self.fst = read_fst_kaldi(cfg.hlg_graph_path)
        self.symbol_table = SymbolTable.read_text(cfg.output_dict)

        self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)
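
    # End-to-end flow: generate() strips decoder-only inputs from the sample,
    # get_emissions() runs the acoustic model(s) and normalizes their outputs,
    # and decode() farms one decode_one() call per utterance out to the pool.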
    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions, padding = self.get_emissions(models, encoder_input)
        return self.decode(emissions, padding)
    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models[0]

        all_encoder_out = [m(**encoder_input) for m in models]

        if len(all_encoder_out) > 1:
            # ensemble: average the encoder outputs across models, keeping the
            # padding mask of the first model
            if "encoder_out" in all_encoder_out[0]:
                encoder_out = {
                    "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
                }
                padding = encoder_out["encoder_padding_mask"]
            else:
                encoder_out = {
                    "logits": sum(e["logits"] for e in all_encoder_out)
                    / len(all_encoder_out),
                    "padding_mask": all_encoder_out[0]["padding_mask"],
                }
                padding = encoder_out["padding_mask"]
        else:
            encoder_out = all_encoder_out[0]
            padding = (
                encoder_out["padding_mask"]
                if "padding_mask" in encoder_out
                else encoder_out["encoder_padding_mask"]
            )

        if hasattr(model, "get_logits"):
            emissions = model.get_logits(encoder_out, normalize=True)
        else:
            emissions = model.get_normalized_probs(encoder_out, log_probs=True)

        # (T, B, V) -> (B, T, V); drop the padding mask entirely if no frame
        # is actually padded
        return (
            emissions.cpu().float().transpose(0, 1),
            padding.cpu() if padding is not None and padding.any() else None,
        )
    def decode_one(self, logits, padding):
        from kaldi.matrix import Matrix

        decoder = self.dec_cls(self.fst, self.decoder_options)
        asr = self.rec_cls(
            decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
        )

        # strip padded frames before handing the emissions to Kaldi
        if padding is not None:
            logits = logits[~padding]

        mat = Matrix(logits.numpy())

        out = asr.decode(mat)

        if self.nbest > 1:
            from kaldi.fstext import shortestpath
            from kaldi.fstext.utils import (
                convert_compact_lattice_to_lattice,
                convert_lattice_to_std,
                convert_nbest_to_list,
                get_linear_symbol_sequence,
            )

            lat = out["lattice"]

            sp = shortestpath(lat, nshortest=self.nbest)

            sp = convert_compact_lattice_to_lattice(sp)
            sp = convert_lattice_to_std(sp)
            seq = convert_nbest_to_list(sp)

            results = []
            for s in seq:
                _, o, w = get_linear_symbol_sequence(s)
                words = list(self.output_symbols[z] for z in o)
                results.append(
                    {
                        "tokens": words,
                        "words": words,
                        "score": w.value,
                        "emissions": logits,
                    }
                )
            return results
        else:
            words = out["text"].split()
            return [
                {
                    "tokens": words,
                    "words": words,
                    "score": out["likelihood"],
                    "emissions": logits,
                }
            ]
    def decode(self, emissions, padding):
        """Submit one decoding job per utterance; returns a list of futures."""
        if padding is None:
            padding = [None] * len(emissions)

        ret = list(
            map(
                lambda e, p: self.executor.submit(self.decode_one, e, p),
                emissions,
                padding,
            )
        )
        return ret
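
# Usage sketch (illustrative only; `models`, `sample`, and the paths below are
# placeholders, not artifacts shipped with this module):
#
#   cfg = KaldiDecoderConfig(
#       hlg_graph_path="/path/to/HLG.fst",  # precompiled decoding graph
#       output_dict="/path/to/words.txt",   # "<word> <id>" lines
#   )
#   decoder = KaldiDecoder(cfg, beam=15)
#   futures = decoder.generate(models, sample)
#   hypos = [f.result() for f in futures]   # each result is a list of dicts
#   print(" ".join(hypos[0][0]["words"]))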