Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from dataclasses import dataclass | |
import hydra | |
from hydra.core.config_store import ConfigStore | |
import logging | |
from omegaconf import MISSING, OmegaConf | |
import os | |
import os.path as osp | |
from pathlib import Path | |
import subprocess | |
from typing import Optional | |
from fairseq.data.dictionary import Dictionary | |
from fairseq.dataclass import FairseqDataclass | |
script_dir = Path(__file__).resolve().parent | |
config_path = script_dir / "config" | |
logger = logging.getLogger(__name__) | |
class KaldiInitializerConfig(FairseqDataclass): | |
data_dir: str = MISSING | |
fst_dir: Optional[str] = None | |
in_labels: str = MISSING | |
out_labels: Optional[str] = None | |
wav2letter_lexicon: Optional[str] = None | |
lm_arpa: str = MISSING | |
kaldi_root: str = MISSING | |
blank_symbol: str = "<s>" | |
silence_symbol: Optional[str] = None | |
def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path: | |
in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt" | |
if not in_units_file.exists(): | |
logger.info(f"Creating {in_units_file}") | |
with open(in_units_file, "w") as f: | |
print("<eps> 0", file=f) | |
i = 1 | |
for symb in vocab.symbols[vocab.nspecial :]: | |
if not symb.startswith("madeupword"): | |
print(f"{symb} {i}", file=f) | |
i += 1 | |
return in_units_file | |
def create_lexicon( | |
cfg: KaldiInitializerConfig, | |
fst_dir: Path, | |
unique_label: str, | |
in_units_file: Path, | |
out_words_file: Path, | |
) -> (Path, Path): | |
disambig_in_units_file = fst_dir / f"kaldi_dict.{cfg.in_labels}_disambig.txt" | |
lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}.txt" | |
disambig_lexicon_file = fst_dir / f"kaldi_lexicon.{unique_label}_disambig.txt" | |
if ( | |
not lexicon_file.exists() | |
or not disambig_lexicon_file.exists() | |
or not disambig_in_units_file.exists() | |
): | |
logger.info(f"Creating {lexicon_file} (in units file: {in_units_file})") | |
assert cfg.wav2letter_lexicon is not None or cfg.in_labels == cfg.out_labels | |
if cfg.wav2letter_lexicon is not None: | |
lm_words = set() | |
with open(out_words_file, "r") as lm_dict_f: | |
for line in lm_dict_f: | |
lm_words.add(line.split()[0]) | |
num_skipped = 0 | |
total = 0 | |
with open(cfg.wav2letter_lexicon, "r") as w2l_lex_f, open( | |
lexicon_file, "w" | |
) as out_f: | |
for line in w2l_lex_f: | |
items = line.rstrip().split("\t") | |
assert len(items) == 2, items | |
if items[0] in lm_words: | |
print(items[0], items[1], file=out_f) | |
else: | |
num_skipped += 1 | |
logger.debug( | |
f"Skipping word {items[0]} as it was not found in LM" | |
) | |
total += 1 | |
if num_skipped > 0: | |
logger.warning( | |
f"Skipped {num_skipped} out of {total} words as they were not found in LM" | |
) | |
else: | |
with open(in_units_file, "r") as in_f, open(lexicon_file, "w") as out_f: | |
for line in in_f: | |
symb = line.split()[0] | |
if symb != "<eps>" and symb != "<ctc_blank>" and symb != "<SIL>": | |
print(symb, symb, file=out_f) | |
lex_disambig_path = ( | |
Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_lex_disambig.pl" | |
) | |
res = subprocess.run( | |
[lex_disambig_path, lexicon_file, disambig_lexicon_file], | |
check=True, | |
capture_output=True, | |
) | |
ndisambig = int(res.stdout) | |
disamib_path = Path(cfg.kaldi_root) / "egs/wsj/s5/utils/add_disambig.pl" | |
res = subprocess.run( | |
[disamib_path, "--include-zero", in_units_file, str(ndisambig)], | |
check=True, | |
capture_output=True, | |
) | |
with open(disambig_in_units_file, "wb") as f: | |
f.write(res.stdout) | |
return disambig_lexicon_file, disambig_in_units_file | |
def create_G( | |
kaldi_root: Path, fst_dir: Path, lm_arpa: Path, arpa_base: str | |
) -> (Path, Path): | |
out_words_file = fst_dir / f"kaldi_dict.{arpa_base}.txt" | |
grammar_graph = fst_dir / f"G_{arpa_base}.fst" | |
if not grammar_graph.exists() or not out_words_file.exists(): | |
logger.info(f"Creating {grammar_graph}") | |
arpa2fst = kaldi_root / "src/lmbin/arpa2fst" | |
subprocess.run( | |
[ | |
arpa2fst, | |
"--disambig-symbol=#0", | |
f"--write-symbol-table={out_words_file}", | |
lm_arpa, | |
grammar_graph, | |
], | |
check=True, | |
) | |
return grammar_graph, out_words_file | |
def create_L( | |
kaldi_root: Path, | |
fst_dir: Path, | |
unique_label: str, | |
lexicon_file: Path, | |
in_units_file: Path, | |
out_words_file: Path, | |
) -> Path: | |
lexicon_graph = fst_dir / f"L.{unique_label}.fst" | |
if not lexicon_graph.exists(): | |
logger.info(f"Creating {lexicon_graph} (in units: {in_units_file})") | |
make_lex = kaldi_root / "egs/wsj/s5/utils/make_lexicon_fst.pl" | |
fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" | |
fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" | |
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" | |
def write_disambig_symbol(file): | |
with open(file, "r") as f: | |
for line in f: | |
items = line.rstrip().split() | |
if items[0] == "#0": | |
out_path = str(file) + "_disamig" | |
with open(out_path, "w") as out_f: | |
print(items[1], file=out_f) | |
return out_path | |
return None | |
in_disambig_sym = write_disambig_symbol(in_units_file) | |
assert in_disambig_sym is not None | |
out_disambig_sym = write_disambig_symbol(out_words_file) | |
assert out_disambig_sym is not None | |
try: | |
with open(lexicon_graph, "wb") as out_f: | |
res = subprocess.run( | |
[make_lex, lexicon_file], capture_output=True, check=True | |
) | |
assert len(res.stderr) == 0, res.stderr.decode("utf-8") | |
res = subprocess.run( | |
[ | |
fstcompile, | |
f"--isymbols={in_units_file}", | |
f"--osymbols={out_words_file}", | |
"--keep_isymbols=false", | |
"--keep_osymbols=false", | |
], | |
input=res.stdout, | |
capture_output=True, | |
) | |
assert len(res.stderr) == 0, res.stderr.decode("utf-8") | |
res = subprocess.run( | |
[fstaddselfloops, in_disambig_sym, out_disambig_sym], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstarcsort, "--sort_type=olabel"], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
out_f.write(res.stdout) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
os.remove(lexicon_graph) | |
raise | |
except AssertionError: | |
os.remove(lexicon_graph) | |
raise | |
return lexicon_graph | |
def create_LG( | |
kaldi_root: Path, | |
fst_dir: Path, | |
unique_label: str, | |
lexicon_graph: Path, | |
grammar_graph: Path, | |
) -> Path: | |
lg_graph = fst_dir / f"LG.{unique_label}.fst" | |
if not lg_graph.exists(): | |
logger.info(f"Creating {lg_graph}") | |
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" | |
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" | |
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" | |
fstpushspecial = kaldi_root / "src/fstbin/fstpushspecial" | |
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" | |
try: | |
with open(lg_graph, "wb") as out_f: | |
res = subprocess.run( | |
[fsttablecompose, lexicon_graph, grammar_graph], | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[ | |
fstdeterminizestar, | |
"--use-log=true", | |
], | |
input=res.stdout, | |
capture_output=True, | |
) | |
res = subprocess.run( | |
[fstminimizeencoded], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstpushspecial], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstarcsort, "--sort_type=ilabel"], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
out_f.write(res.stdout) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
os.remove(lg_graph) | |
raise | |
return lg_graph | |
def create_H( | |
kaldi_root: Path, | |
fst_dir: Path, | |
disambig_out_units_file: Path, | |
in_labels: str, | |
vocab: Dictionary, | |
blk_sym: str, | |
silence_symbol: Optional[str], | |
) -> (Path, Path, Path): | |
h_graph = ( | |
fst_dir / f"H.{in_labels}{'_' + silence_symbol if silence_symbol else ''}.fst" | |
) | |
h_out_units_file = fst_dir / f"kaldi_dict.h_out.{in_labels}.txt" | |
disambig_in_units_file_int = Path(str(h_graph) + "isym_disambig.int") | |
disambig_out_units_file_int = Path(str(disambig_out_units_file) + ".int") | |
if ( | |
not h_graph.exists() | |
or not h_out_units_file.exists() | |
or not disambig_in_units_file_int.exists() | |
): | |
logger.info(f"Creating {h_graph}") | |
eps_sym = "<eps>" | |
num_disambig = 0 | |
osymbols = [] | |
with open(disambig_out_units_file, "r") as f, open( | |
disambig_out_units_file_int, "w" | |
) as out_f: | |
for line in f: | |
symb, id = line.rstrip().split() | |
if line.startswith("#"): | |
num_disambig += 1 | |
print(id, file=out_f) | |
else: | |
if len(osymbols) == 0: | |
assert symb == eps_sym, symb | |
osymbols.append((symb, id)) | |
i_idx = 0 | |
isymbols = [(eps_sym, 0)] | |
imap = {} | |
for i, s in enumerate(vocab.symbols): | |
i_idx += 1 | |
isymbols.append((s, i_idx)) | |
imap[s] = i_idx | |
fst_str = [] | |
node_idx = 0 | |
root_node = node_idx | |
special_symbols = [blk_sym] | |
if silence_symbol is not None: | |
special_symbols.append(silence_symbol) | |
for ss in special_symbols: | |
fst_str.append("{} {} {} {}".format(root_node, root_node, ss, eps_sym)) | |
for symbol, _ in osymbols: | |
if symbol == eps_sym or symbol.startswith("#"): | |
continue | |
node_idx += 1 | |
# 1. from root to emitting state | |
fst_str.append("{} {} {} {}".format(root_node, node_idx, symbol, symbol)) | |
# 2. from emitting state back to root | |
fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) | |
# 3. from emitting state to optional blank state | |
pre_node = node_idx | |
node_idx += 1 | |
for ss in special_symbols: | |
fst_str.append("{} {} {} {}".format(pre_node, node_idx, ss, eps_sym)) | |
# 4. from blank state back to root | |
fst_str.append("{} {} {} {}".format(node_idx, root_node, eps_sym, eps_sym)) | |
fst_str.append("{}".format(root_node)) | |
fst_str = "\n".join(fst_str) | |
h_str = str(h_graph) | |
isym_file = h_str + ".isym" | |
with open(isym_file, "w") as f: | |
for sym, id in isymbols: | |
f.write("{} {}\n".format(sym, id)) | |
with open(h_out_units_file, "w") as f: | |
for sym, id in osymbols: | |
f.write("{} {}\n".format(sym, id)) | |
with open(disambig_in_units_file_int, "w") as f: | |
disam_sym_id = len(isymbols) | |
for _ in range(num_disambig): | |
f.write("{}\n".format(disam_sym_id)) | |
disam_sym_id += 1 | |
fstcompile = kaldi_root / "tools/openfst-1.6.7/bin/fstcompile" | |
fstaddselfloops = kaldi_root / "src/fstbin/fstaddselfloops" | |
fstarcsort = kaldi_root / "tools/openfst-1.6.7/bin/fstarcsort" | |
try: | |
with open(h_graph, "wb") as out_f: | |
res = subprocess.run( | |
[ | |
fstcompile, | |
f"--isymbols={isym_file}", | |
f"--osymbols={h_out_units_file}", | |
"--keep_isymbols=false", | |
"--keep_osymbols=false", | |
], | |
input=str.encode(fst_str), | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[ | |
fstaddselfloops, | |
disambig_in_units_file_int, | |
disambig_out_units_file_int, | |
], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstarcsort, "--sort_type=olabel"], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
out_f.write(res.stdout) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
os.remove(h_graph) | |
raise | |
return h_graph, h_out_units_file, disambig_in_units_file_int | |
def create_HLGa( | |
kaldi_root: Path, | |
fst_dir: Path, | |
unique_label: str, | |
h_graph: Path, | |
lg_graph: Path, | |
disambig_in_words_file_int: Path, | |
) -> Path: | |
hlga_graph = fst_dir / f"HLGa.{unique_label}.fst" | |
if not hlga_graph.exists(): | |
logger.info(f"Creating {hlga_graph}") | |
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" | |
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" | |
fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" | |
fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal" | |
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" | |
try: | |
with open(hlga_graph, "wb") as out_f: | |
res = subprocess.run( | |
[ | |
fsttablecompose, | |
h_graph, | |
lg_graph, | |
], | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstdeterminizestar, "--use-log=true"], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstrmsymbols, disambig_in_words_file_int], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstrmepslocal], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstminimizeencoded], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
out_f.write(res.stdout) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
os.remove(hlga_graph) | |
raise | |
return hlga_graph | |
def create_HLa( | |
kaldi_root: Path, | |
fst_dir: Path, | |
unique_label: str, | |
h_graph: Path, | |
l_graph: Path, | |
disambig_in_words_file_int: Path, | |
) -> Path: | |
hla_graph = fst_dir / f"HLa.{unique_label}.fst" | |
if not hla_graph.exists(): | |
logger.info(f"Creating {hla_graph}") | |
fsttablecompose = kaldi_root / "src/fstbin/fsttablecompose" | |
fstdeterminizestar = kaldi_root / "src/fstbin/fstdeterminizestar" | |
fstrmsymbols = kaldi_root / "src/fstbin/fstrmsymbols" | |
fstrmepslocal = kaldi_root / "src/fstbin/fstrmepslocal" | |
fstminimizeencoded = kaldi_root / "src/fstbin/fstminimizeencoded" | |
try: | |
with open(hla_graph, "wb") as out_f: | |
res = subprocess.run( | |
[ | |
fsttablecompose, | |
h_graph, | |
l_graph, | |
], | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstdeterminizestar, "--use-log=true"], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstrmsymbols, disambig_in_words_file_int], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstrmepslocal], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
res = subprocess.run( | |
[fstminimizeencoded], | |
input=res.stdout, | |
capture_output=True, | |
check=True, | |
) | |
out_f.write(res.stdout) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
os.remove(hla_graph) | |
raise | |
return hla_graph | |
def create_HLG( | |
kaldi_root: Path, | |
fst_dir: Path, | |
unique_label: str, | |
hlga_graph: Path, | |
prefix: str = "HLG", | |
) -> Path: | |
hlg_graph = fst_dir / f"{prefix}.{unique_label}.fst" | |
if not hlg_graph.exists(): | |
logger.info(f"Creating {hlg_graph}") | |
add_self_loop = script_dir / "add-self-loop-simple" | |
kaldi_src = kaldi_root / "src" | |
kaldi_lib = kaldi_src / "lib" | |
try: | |
if not add_self_loop.exists(): | |
fst_include = kaldi_root / "tools/openfst-1.6.7/include" | |
add_self_loop_src = script_dir / "add-self-loop-simple.cc" | |
subprocess.run( | |
[ | |
"c++", | |
f"-I{kaldi_src}", | |
f"-I{fst_include}", | |
f"-L{kaldi_lib}", | |
add_self_loop_src, | |
"-lkaldi-base", | |
"-lkaldi-fstext", | |
"-o", | |
add_self_loop, | |
], | |
check=True, | |
) | |
my_env = os.environ.copy() | |
my_env["LD_LIBRARY_PATH"] = f"{kaldi_lib}:{my_env['LD_LIBRARY_PATH']}" | |
subprocess.run( | |
[ | |
add_self_loop, | |
hlga_graph, | |
hlg_graph, | |
], | |
check=True, | |
capture_output=True, | |
env=my_env, | |
) | |
except subprocess.CalledProcessError as e: | |
logger.error(f"cmd: {e.cmd}, err: {e.stderr.decode('utf-8')}") | |
raise | |
return hlg_graph | |
def initalize_kaldi(cfg: KaldiInitializerConfig) -> Path: | |
if cfg.fst_dir is None: | |
cfg.fst_dir = osp.join(cfg.data_dir, "kaldi") | |
if cfg.out_labels is None: | |
cfg.out_labels = cfg.in_labels | |
kaldi_root = Path(cfg.kaldi_root) | |
data_dir = Path(cfg.data_dir) | |
fst_dir = Path(cfg.fst_dir) | |
fst_dir.mkdir(parents=True, exist_ok=True) | |
arpa_base = osp.splitext(osp.basename(cfg.lm_arpa))[0] | |
unique_label = f"{cfg.in_labels}.{arpa_base}" | |
with open(data_dir / f"dict.{cfg.in_labels}.txt", "r") as f: | |
vocab = Dictionary.load(f) | |
in_units_file = create_units(fst_dir, cfg.in_labels, vocab) | |
grammar_graph, out_words_file = create_G( | |
kaldi_root, fst_dir, Path(cfg.lm_arpa), arpa_base | |
) | |
disambig_lexicon_file, disambig_L_in_units_file = create_lexicon( | |
cfg, fst_dir, unique_label, in_units_file, out_words_file | |
) | |
h_graph, h_out_units_file, disambig_in_units_file_int = create_H( | |
kaldi_root, | |
fst_dir, | |
disambig_L_in_units_file, | |
cfg.in_labels, | |
vocab, | |
cfg.blank_symbol, | |
cfg.silence_symbol, | |
) | |
lexicon_graph = create_L( | |
kaldi_root, | |
fst_dir, | |
unique_label, | |
disambig_lexicon_file, | |
disambig_L_in_units_file, | |
out_words_file, | |
) | |
lg_graph = create_LG( | |
kaldi_root, fst_dir, unique_label, lexicon_graph, grammar_graph | |
) | |
hlga_graph = create_HLGa( | |
kaldi_root, fst_dir, unique_label, h_graph, lg_graph, disambig_in_units_file_int | |
) | |
hlg_graph = create_HLG(kaldi_root, fst_dir, unique_label, hlga_graph) | |
# for debugging | |
# hla_graph = create_HLa(kaldi_root, fst_dir, unique_label, h_graph, lexicon_graph, disambig_in_units_file_int) | |
# hl_graph = create_HLG(kaldi_root, fst_dir, unique_label, hla_graph, prefix="HL_looped") | |
# create_HLG(kaldi_root, fst_dir, "phnc", h_graph, prefix="H_looped") | |
return hlg_graph | |
def cli_main(cfg: KaldiInitializerConfig) -> None: | |
container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) | |
cfg = OmegaConf.create(container) | |
OmegaConf.set_struct(cfg, True) | |
initalize_kaldi(cfg) | |
if __name__ == "__main__": | |
logging.root.setLevel(logging.INFO) | |
logging.basicConfig(level=logging.INFO) | |
try: | |
from hydra._internal.utils import ( | |
get_args, | |
) # pylint: disable=import-outside-toplevel | |
cfg_name = get_args().config_name or "kaldi_initializer" | |
except ImportError: | |
logger.warning("Failed to get config name from hydra args") | |
cfg_name = "kaldi_initializer" | |
cs = ConfigStore.instance() | |
cs.store(name=cfg_name, node=KaldiInitializerConfig) | |
cli_main() | |