# SLT-Task2-ngram-baseline / speaker_tagging_beamsearch.py
# Uploaded by Taejin — commit 5917f0a (verified): "Uploading ngram base model" (3.75 kB).
import hydra
from typing import List, Optional
from dataclasses import dataclass, field
import kenlm
from beam_search_utils import (
SpeakerTaggingBeamSearchDecoder,
load_input_jsons,
load_reference_jsons,
write_seglst_jsons,
run_mp_beam_search_decoding,
convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore
# Prefix placed in front of the informational messages this script prints.
__INFO_TAG__: str = "[INFO]"
@dataclass
class RealigningLanguageModelParameters:
    """Structured configuration for the speaker-tagging beam-search script.

    Registered with Hydra's ConfigStore under the name "config"; ``main``
    receives an instance of it as ``cfg``. The whole object is also handed to
    ``SpeakerTaggingBeamSearchDecoder(cfg=cfg)``, so fields not read directly in
    ``main`` may still be consumed there (not visible from this file).
    """

    # NOTE(review): not read directly in main(); presumably consumed by the
    # decoder or helpers through cfg — confirm.
    batch_size: int = 32
    # NOTE(review): not read in main(), which always passes div_mp=True to
    # run_mp_beam_search_decoding — confirm whether this flag is still used.
    use_mp: bool = True
    # List path read by load_input_jsons and load_reference_jsons (source/hypothesis input).
    input_error_src_list_path: Optional[str] = None
    # List path read by load_reference_jsons for the ground-truth reference seglsts.
    groundtruth_ref_list_path: Optional[str] = None
    # Language model path loaded with kenlm.Model (field name suggests ARPA format).
    arpa_language_model: Optional[str] = None
    # Passed as word_window to divide_chunks and run_mp_beam_search_decoding.
    word_window: int = 32
    # Passed as port to divide_chunks and run_mp_beam_search_decoding; empty list by default.
    port: List[int] = field(default_factory=list)
    # Passed as win_len to divide_chunks and run_mp_beam_search_decoding
    # (printed as "Chunk Word Len").
    parallel_chunk_word_len: int = 250
    # Forwarded as use_ngram to run_mp_beam_search_decoding.
    use_ngram: bool = True
    # Cast to float and passed as peak_prob to load_input_jsons.
    peak_prob: float = 0.95
    # alpha, beta, beam_width are only printed by main(); presumably read by
    # SpeakerTaggingBeamSearchDecoder from cfg — confirm.
    alpha: float = 0.5
    beta: float = 0.05
    beam_width: int = 16
    # Output directory passed as diar_out_path to write_seglst_jsons.
    out_dir: Optional[str] = None
# Register the structured config with Hydra under the name "config" so that
# @hydra.main(config_name="config") can build `cfg` from it.
cs = ConfigStore.instance()
cs.store(node=RealigningLanguageModelParameters, name="config")
@hydra.main(config_name="config", version_base="1.1")
def main(cfg: RealigningLanguageModelParameters) -> None:
    """Run n-gram speaker-tagging beam search and write seglst JSON outputs.

    Steps performed here:
      1. Load the inputs listed at ``cfg.input_error_src_list_path`` with
         ``load_input_jsons`` (using ``cfg.peak_prob``), and load the reference
         and source seglst lists with ``load_reference_jsons``.
      2. Load the language model with ``kenlm.Model`` and build a
         ``SpeakerTaggingBeamSearchDecoder`` from it and ``cfg``.
      3. Split the inputs with ``divide_chunks``, decode them with
         ``run_mp_beam_search_decoding``, and convert the result with
         ``convert_nemo_json_to_seglst``.
      4. Write the 'hyp', 'ref' and 'src' seglst JSONs under ``cfg.out_dir``
         and print the parameters that were used.

    Args:
        cfg: Structured config populated by Hydra.

    Raises:
        ValueError: If ``input_error_src_list_path``, ``groundtruth_ref_list_path``
            or ``arpa_language_model`` is not set.
    """
    # Fail fast with a readable message. Otherwise a missing path surfaces as an
    # obscure error from a helper or from kenlm, possibly after slow work.
    required = ("input_error_src_list_path", "groundtruth_ref_list_path", "arpa_language_model")
    missing = [name for name in required if getattr(cfg, name) is None]
    if missing:
        raise ValueError(
            f"Missing required config value(s): {', '.join(missing)} "
            f"(set them as Hydra overrides, e.g. {missing[0]}=/path/to/file)"
        )

    trans_info_dict = load_input_jsons(
        input_error_src_list_path=cfg.input_error_src_list_path,
        peak_prob=float(cfg.peak_prob),
    )
    reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
    source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)

    loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
    speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)

    div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(
        trans_info_dict=trans_info_dict,
        win_len=cfg.parallel_chunk_word_len,
        word_window=cfg.word_window,
        port=cfg.port,
    )
    # The undivided dict is passed as org_trans_info_dict alongside the chunked one.
    trans_info_dict = run_mp_beam_search_decoding(
        speaker_beam_search_decoder,
        loaded_kenlm_model=loaded_kenlm_model,
        trans_info_dict=div_trans_info_dict,
        org_trans_info_dict=trans_info_dict,
        div_mp=True,  # NOTE(review): hard-coded; cfg.use_mp is never read here — confirm intent.
        win_len=cfg.parallel_chunk_word_len,
        word_window=cfg.word_window,
        port=cfg.port,
        use_ngram=cfg.use_ngram,
    )

    hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
    # NOTE(review): source_info_dict was loaded from input_error_src_list_path but
    # is written with groundtruth_ref_list_path. Kept as-is; confirm which path
    # write_seglst_jsons should use for the 'src' output.
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')

    # Build the message from concatenated f-string pieces rather than a
    # backslash-continued literal. The old form embedded source indentation in
    # the output and ended with a stray line-continuation backslash.
    print(
        f"{__INFO_TAG__} Parameters used:"
        f"\n ALPHA: {cfg.alpha}"
        f"\n BETA: {cfg.beta}"
        f"\n BEAM WIDTH: {cfg.beam_width}"
        f"\n Word Window: {cfg.word_window}"
        f"\n Use Ngram: {cfg.use_ngram}"
        f"\n Chunk Word Len: {cfg.parallel_chunk_word_len}"
        f"\n SpeakerLM Model: {cfg.arpa_language_model}"
    )
if __name__ == "__main__":
    # The @hydra.main decorator supplies `cfg`, so main is called with no arguments.
    main()