File size: 3,746 Bytes
5917f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import hydra
from typing import List, Optional
from dataclasses import dataclass, field
import kenlm
from beam_search_utils import (
    SpeakerTaggingBeamSearchDecoder,
    load_input_jsons,
    load_reference_jsons,
    write_seglst_jsons,
    run_mp_beam_search_decoding,
    convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore

# Prefix prepended to informational console output emitted by this script.
__INFO_TAG__: str = "[INFO]"

@dataclass
class RealigningLanguageModelParameters:
    """Structured Hydra config schema for the speaker-tagging beam search script.

    Registered with Hydra's ConfigStore and received by ``main`` as ``cfg``.
    Field order and defaults are part of the schema; fields defaulting to
    ``None`` are paths that presumably have to be supplied by the user
    (e.g. as command-line overrides) — confirm per field.
    """

    # --- batching / parallelism (not read directly in main(); presumably
    # consumed by SpeakerTaggingBeamSearchDecoder via cfg — TODO confirm) ---
    batch_size: int = field(default=32)
    use_mp: bool = field(default=True)

    # --- inputs ---
    # List of error-source transcript files, read by load_input_jsons/load_reference_jsons.
    input_error_src_list_path: Optional[str] = field(default=None)
    # List of ground-truth reference files, read by load_reference_jsons.
    groundtruth_ref_list_path: Optional[str] = field(default=None)
    # Path to the ARPA file loaded with kenlm.Model.
    arpa_language_model: Optional[str] = field(default=None)

    # --- chunking / decoding ---
    word_window: int = field(default=32)
    port: List[int] = field(default_factory=list)
    parallel_chunk_word_len: int = field(default=250)
    use_ngram: bool = field(default=True)
    # Passed as ``peak_prob`` to load_input_jsons.
    peak_prob: float = field(default=0.95)
    # Beam-search weights and width; presumably used by SpeakerTaggingBeamSearchDecoder.
    alpha: float = field(default=0.5)
    beta: float = field(default=0.05)
    beam_width: int = field(default=16)

    # --- output ---
    # Directory passed as ``diar_out_path`` to write_seglst_jsons.
    out_dir: Optional[str] = field(default=None)

# Register the dataclass schema under the name "config" so that the
# @hydra.main(config_name="config") entry point resolves to it.
cs = ConfigStore.instance()
cs.store(node=RealigningLanguageModelParameters, name="config")

def _format_parameter_summary(cfg: RealigningLanguageModelParameters) -> str:
    """Return a multi-line, human-readable summary of the main decoding parameters."""
    # Built with an explicit join: the previous backslash-continued f-string
    # leaked the source indentation into the printed text as trailing spaces.
    summary_lines = [
        f"{__INFO_TAG__} Parameters used:",
        f" ALPHA: {cfg.alpha}",
        f" BETA: {cfg.beta}",
        f" BEAM WIDTH: {cfg.beam_width}",
        f" Word Window: {cfg.word_window}",
        f" Use Ngram: {cfg.use_ngram}",
        f" Chunk Word Len: {cfg.parallel_chunk_word_len}",
        f" SpeakerLM Model: {cfg.arpa_language_model}",
    ]
    return "\n".join(summary_lines)


@hydra.main(config_name="config", version_base="1.1")
def main(cfg: RealigningLanguageModelParameters) -> None:
    """Run speaker-tagging beam search decoding and write SegLST outputs.

    Steps:
      1. Load the error-source transcripts, the ground-truth references and the
         source transcripts (again, via ``load_reference_jsons``).
      2. Load the KenLM model from ``cfg.arpa_language_model``.
      3. Split the transcripts into chunks and decode them with
         ``SpeakerTaggingBeamSearchDecoder`` via ``run_mp_beam_search_decoding``.
      4. Write hypothesis (``hyp``), reference (``ref``) and source (``src``)
         SegLST JSONs under ``cfg.out_dir``, then print the parameters used.

    Args:
        cfg: Hydra config built from ``RealigningLanguageModelParameters``.

    Raises:
        ValueError: if ``input_error_src_list_path`` or ``arpa_language_model``
            was left at its ``None`` default.
    """
    # Both paths are dereferenced unconditionally below; fail fast with a
    # readable message instead of an opaque TypeError from the loaders / kenlm.
    missing = [
        name
        for name in ("input_error_src_list_path", "arpa_language_model")
        if getattr(cfg, name) is None
    ]
    if missing:
        raise ValueError(f"Missing required config value(s): {', '.join(missing)}")

    trans_info_dict = load_input_jsons(
        input_error_src_list_path=cfg.input_error_src_list_path,
        peak_prob=float(cfg.peak_prob),
    )
    reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
    # The source transcripts are loaded a second time, through the reference
    # loader, so they can be written out unchanged with ext_str='src' below.
    source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)
    loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)

    speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)

    # Split each session into chunks (size/context presumably governed by
    # parallel_chunk_word_len / word_window — see divide_chunks) for parallel decoding.
    div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(
        trans_info_dict=trans_info_dict,
        win_len=cfg.parallel_chunk_word_len,
        word_window=cfg.word_window,
        port=cfg.port,
    )

    # The undivided dict is passed as org_trans_info_dict (presumably to stitch
    # the decoded chunks back together); the result replaces trans_info_dict.
    trans_info_dict = run_mp_beam_search_decoding(
        speaker_beam_search_decoder,
        loaded_kenlm_model=loaded_kenlm_model,
        trans_info_dict=div_trans_info_dict,
        org_trans_info_dict=trans_info_dict,
        div_mp=True,
        win_len=cfg.parallel_chunk_word_len,
        word_window=cfg.word_window,
        port=cfg.port,
        use_ngram=cfg.use_ngram,
    )
    hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)

    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
    # NOTE(review): source_info_dict was loaded from input_error_src_list_path but is
    # written using groundtruth_ref_list_path. Kept as-is because write_seglst_jsons
    # is not visible here — confirm this is intentional.
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='src')
    print(_format_parameter_summary(cfg))


if __name__ == '__main__':
    main()