|
from tqdm import tqdm |
|
from typing import Dict, List |
|
from pydiardecode import build_diardecoder |
|
import numpy as np |
|
import copy |
|
import os |
|
import json |
|
import concurrent.futures |
|
import kenlm |
|
|
|
# Prefix prepended to the informational messages this module prints.
__INFO_TAG__ = "[BeamSearchUtil INFO]"
|
|
|
class SpeakerTaggingBeamSearchDecoder:
    """
    Realign per-word speaker labels with a beam-search decoder created by
    ``build_diardecoder``, and split/merge long word sequences so that
    sessions can be decoded as parallel chunks.
    """

    def __init__(self, loaded_kenlm_model: "kenlm.Model", cfg: dict):
        """
        Args:
            loaded_kenlm_model: KenLM model object forwarded to ``build_diardecoder``.
            cfg (dict): Decoder parameters. Keys read by this class:
                ``arpa_language_model``, ``alpha``, ``beta``, ``word_window``,
                ``use_ngram`` (decoder construction) and ``beam_width`` (decoding).
        """
        self.realigning_lm_params = cfg
        self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
        # Separator of the "<uniq_id>@<chunk_idx>@<chunk_count>" keys written by
        # divide_chunks() and parsed by merge_div_inputs().
        self._SPLITSYM = "@"

    def _load_realigning_LM(self, loaded_kenlm_model: "kenlm.Model"):
        """
        Build the decoder used to realign speaker labels for words.

        Args:
            loaded_kenlm_model: KenLM model object forwarded to ``build_diardecoder``.

        Returns:
            The object returned by ``build_diardecoder``; its ``decode_beams``
            method is called by ``realign_words_with_lm``.
        """
        return build_diardecoder(
            loaded_kenlm_model=loaded_kenlm_model,
            kenlm_model_path=self.realigning_lm_params['arpa_language_model'],
            alpha=self.realigning_lm_params['alpha'],
            beta=self.realigning_lm_params['beta'],
            word_window=self.realigning_lm_params['word_window'],
            use_ngram=self.realigning_lm_params['use_ngram'],
        )

    def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num=None) -> List[Dict[str, float]]:
        """
        Run beam-search decoding over one word sequence to realign its speaker labels.

        Args:
            word_dict_seq_list: Word dicts. Each needs a ``'speaker'`` key when
                ``speaker_count`` is None.
            speaker_count: If given, the candidate speakers are
                ``speaker_0 .. speaker_{speaker_count - 1}``; otherwise they are the
                distinct ``'speaker'`` values found in ``word_dict_seq_list``.
            port_num: Forwarded unchanged to ``decode_beams``.

        Returns:
            Whatever ``decode_beams`` returns; ``beam_search_diarization`` reads
            ``result[0][2]`` as the realigned word list.
        """
        if speaker_count is None:
            spk_list = [word_dict['speaker'] for word_dict in word_dict_seq_list]
        else:
            spk_list = [f"speaker_{k}" for k in range(speaker_count)]

        return self.realigning_lm.decode_beams(
            beam_width=self.realigning_lm_params['beam_width'],
            speaker_list=sorted(set(spk_list)),
            word_dict_seq_list=word_dict_seq_list,
            port_num=port_num,
        )

    def beam_search_diarization(
        self,
        trans_info_dict: Dict[str, Dict[str, list]],
        port_num: List[int] = None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Realign the speaker labels of every session's words with the decoder.

        Each session's ``'words'`` list is replaced in place by the word list of
        the first returned beam (``decode_beams(...)[0][2]``).

        Args:
            trans_info_dict: Session ID -> dict holding at least ``'words'``. If the
                dict also holds ``'speaker_count'``, that count defines the candidate
                speakers; otherwise they are inferred from the words' labels.
            port_num: Forwarded unchanged to ``decode_beams``.
                ``run_mp_beam_search_decoding`` passes a single entry of its port
                list here (an int or None), despite the ``List[int]`` annotation.

        Returns:
            The same ``trans_info_dict`` object, with updated ``'words'``.
        """
        for session_dict in trans_info_dict.values():
            output_beams = self.realign_words_with_lm(
                word_dict_seq_list=session_dict['words'],
                speaker_count=session_dict.get('speaker_count'),
                port_num=port_num,
            )
            # NOTE(review): assumes beams are ordered best-first by decode_beams.
            session_dict['words'] = output_beams[0][2]
        return trans_info_dict

    def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16):
        """
        Merge decoded chunks from ``divide_chunks`` back into whole sessions.

        Chunk keys have the form ``<uniq_id>@<chunk_idx>@<chunk_count>``. Every
        chunk after the first starts with up to ``word_window`` context words that
        the previous chunks already contributed; that overlap is dropped.

        Args:
            div_trans_info_dict (dict): Chunk key -> dict with the decoded ``'words'``.
            org_trans_info_dict (dict): Session ID -> session dict; each session's
                ``'words'`` is overwritten with the merged word list.
            win_len (int or None): Maximum number of words kept from the first chunk
                (None keeps it whole). Should match the value used to divide.
            word_window (int): Context overlap that was used in ``divide_chunks``.

        Returns:
            dict: ``org_trans_info_dict``, updated in place.

        Raises:
            KeyError: If a session that has words has no decoded chunks.
        """
        sub_div_dict = {}
        for seq_id, div_info_dict in div_trans_info_dict.items():
            # rsplit keeps session IDs that themselves contain the separator intact.
            uniq_id, sub_idx, total_count = seq_id.rsplit(self._SPLITSYM, 2)
            if uniq_id not in sub_div_dict:
                sub_div_dict[uniq_id] = [None] * int(total_count)
            sub_div_dict[uniq_id][int(sub_idx)] = div_info_dict['words']

        for uniq_id, session_dict in org_trans_info_dict.items():
            if uniq_id in sub_div_dict:
                div_chunks = sub_div_dict[uniq_id]
            elif not session_dict.get('words'):
                # divide_chunks() emits no chunk for a session without words.
                div_chunks = []
            else:
                raise KeyError(uniq_id)

            merged_words = []
            for k, div_words in enumerate(div_chunks):
                if k == 0:
                    div_words = div_words[:win_len]
                else:
                    # divide_chunks() clamps chunk starts at 0, so the real overlap is
                    # min(word_window, words merged so far); this also stays correct
                    # when win_len < word_window or win_len was chosen per session.
                    div_words = div_words[min(word_window, len(merged_words)):]
                merged_words.extend(div_words)
            session_dict['words'] = merged_words
        return org_trans_info_dict

    def divide_chunks(self, trans_info_dict, win_len, word_window, port):
        """
        Split every session's word sequence into overlapping chunks for parallel decoding.

        Chunk ``k`` is ``words[max(k * win_len - word_window, 0):(k + 1) * win_len]``:
        ``win_len`` new words preceded by up to ``word_window`` context words.
        ``merge_div_inputs`` reassembles the chunks.

        The ``'status'``, ``'transcription'`` and ``'sentences'`` entries are removed
        from the session dicts of ``trans_info_dict`` in place.

        Args:
            trans_info_dict (dict): Session ID -> session dict with a ``'words'`` list.
            win_len (int or None): Number of new words per chunk. If None, it is
                computed separately for each session so that the session splits into
                at most ``max(len(port), 1)`` chunks.
            word_window (int): Number of preceding context words prepended to a chunk.
            port (list or None): Only its length is used, to size chunks when
                ``win_len`` is None.

        Returns:
            dict: ``"<uniq_id>@<chunk_idx>@<chunk_count>"`` -> shallow copy of the
            session dict whose ``'words'`` is the chunk. Sessions without words
            produce no entries.
        """
        num_workers = len(port) if port else 1
        div_trans_info_dict = {}
        for uniq_id, uniq_trans in trans_info_dict.items():
            # Removed from the caller's dict as well (presumably to keep the
            # per-chunk copies sent to worker processes small).
            for key in ('status', 'transcription', 'sentences'):
                uniq_trans.pop(key, None)
            word_seq = uniq_trans['words']

            if win_len is None:
                # Per session: reusing the first session's value would mis-size the rest.
                session_win_len = max(1, int(np.ceil(len(word_seq) / num_workers)))
            else:
                session_win_len = win_len
            n_chunks = int(np.ceil(len(word_seq) / session_win_len))

            div_word_seq = [
                word_seq[max(k * session_win_len - word_window, 0):(k + 1) * session_win_len]
                for k in range(n_chunks)
            ]

            total_count = len(div_word_seq)
            for k, w_seq in enumerate(div_word_seq):
                seq_id = f"{uniq_id}{self._SPLITSYM}{k}{self._SPLITSYM}{total_count}"
                div_trans_info_dict[seq_id] = dict(uniq_trans)
                div_trans_info_dict[seq_id]['words'] = w_seq
        return div_trans_info_dict
|
|
|
def run_mp_beam_search_decoding(
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    div_mp,
    win_len,
    word_window,
    port=None,
    use_ngram=False
):
    """
    Decode every entry of ``div_trans_info_dict`` in a process pool.

    Each entry is submitted as its own job to
    ``speaker_beam_search_decoder.beam_search_diarization``, and ports are handed
    out round-robin.

    Args:
        speaker_beam_search_decoder (SpeakerTaggingBeamSearchDecoder): Decoder whose
            bound ``beam_search_diarization`` is sent to worker processes, so it
            must be picklable.
        loaded_kenlm_model: Not used by this function; kept for interface compatibility.
        div_trans_info_dict (dict): Entry ID -> session (or chunk) dict to decode.
        org_trans_info_dict (dict): Original sessions; only used when ``div_mp``.
        div_mp (bool): If True, the entries are chunks from ``divide_chunks`` and are
            reassembled with ``merge_div_inputs``.
        win_len: Forwarded to ``merge_div_inputs``.
        word_window: Forwarded to ``merge_div_inputs``.
        port (list or None): Ports assigned round-robin to the jobs; entries are
            converted to int. None or an empty list means one worker with
            ``port_num=None``.
        use_ngram (bool): If True, ports are ignored (every job gets ``port_num=None``)
            and a hard-coded pool of 36 workers is used; otherwise one worker per port.

    Returns:
        dict: Decoded entries (merged into sessions when ``div_mp``), in sorted
        entry-ID order when not merged.
    """
    if use_ngram:
        port = [None]
        num_workers = 36
    elif not port:
        port = [None]
        num_workers = 1
    else:
        port = [None if p is None else int(p) for p in port]
        num_workers = len(port)

    uniq_id_list = sorted(div_trans_info_dict)
    futures = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as tp:
        for count, uniq_id in enumerate(uniq_id_list):
            print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
            port_num = port[count % len(port)]
            uniq_trans_info_dict = {uniq_id: div_trans_info_dict[uniq_id]}
            futures.append(
                tp.submit(speaker_beam_search_decoder.beam_search_diarization, uniq_trans_info_dict, port_num=port_num)
            )

        with tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files") as pbar:
            for done_future in concurrent.futures.as_completed(futures):
                # Re-raise a worker's exception as soon as that job finishes.
                done_future.result()
                pbar.update()

    # Collect in submission order so the output order does not depend on scheduling.
    output_trans_info_dict = {}
    for future in futures:
        output_trans_info_dict.update(future.result())

    if div_mp:
        output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(
            div_trans_info_dict=output_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            win_len=win_len,
            word_window=word_window,
        )
    return output_trans_info_dict
|
|
|
def count_num_of_spks(json_trans_list):
    """
    Map each distinct speaker label to a 0-based integer index.

    Labels are indexed in sorted order, so the mapping is reproducible across
    runs. Iterating a ``set`` of strings is not reproducible, because of hash
    randomization. If the labels cannot be compared with each other, their
    first-appearance order is used instead.

    Args:
        json_trans_list (list of dict): Segments, each with a ``'speaker'`` key.

    Returns:
        dict: Speaker label -> index in ``range(number_of_distinct_speakers)``.
    """
    unique_speakers = list(dict.fromkeys(sentence_dict['speaker'] for sentence_dict in json_trans_list))
    try:
        ordered_speakers = sorted(unique_speakers)
    except TypeError:
        # Mixed, non-comparable label types: keep first-appearance order.
        ordered_speakers = unique_speakers
    return {spk_str: idx for idx, spk_str in enumerate(ordered_speakers)}
|
|
|
def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
    """
    Convert seglst segments into a NeMo-style session dict with placeholder
    speaker probabilities.

    Every whitespace-separated word becomes a word dict. Its ``speaker_softmax``
    vector (length ``max_spks``) holds ``peak_prob`` at the segment speaker's
    index and spreads ``1 - peak_prob`` evenly over the other slots. Segment
    timings are not read, so ``start_time``/``end_time`` are None.

    Args:
        json_trans_list (list of dict): Segments with ``'words'`` (str) and
            ``'speaker'`` keys.
        peak_prob (float): Probability given to the labelled speaker, in [0, 1].
        max_spks (int): Length of each ``speaker_softmax`` vector.

    Returns:
        dict: ``'words'`` (word dicts whose ``'speaker'`` is the integer index from
        ``count_num_of_spks``), ``'status'`` ("success"), ``'sentences'`` (the input
        list), ``'speaker_count'`` and ``'transcription'`` (None).

    Raises:
        ValueError: If ``peak_prob`` is outside [0, 1], or if there are more
            distinct speakers than ``max_spks``.
    """
    if peak_prob > 1 or peak_prob < 0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    speaker_map = count_num_of_spks(json_trans_list)
    if len(speaker_map) > max_spks:
        raise ValueError(f"Found {len(speaker_map)} speakers but max_spks is {max_spks}")

    # With a single slot there is no remaining mass to spread (avoids dividing by zero).
    off_peak_prob = (1 - peak_prob) / (max_spks - 1) if max_spks > 1 else 0.0
    base_array = np.full(max_spks, off_peak_prob)

    stt_sec, end_sec = None, None
    word_dict_seq_list = []
    for sentence_dict in json_trans_list:
        spk_idx = speaker_map[sentence_dict['speaker']]
        for word in sentence_dict['words'].split():
            # Each word gets its own vector so later in-place edits stay local.
            speaker_softmax = base_array.copy()
            speaker_softmax[spk_idx] = peak_prob
            word_dict_seq_list.append({
                'word': word,
                'start_time': stt_sec,
                'end_time': end_sec,
                'speaker': spk_idx,
                'speaker_softmax': speaker_softmax,
            })

    return {
        'words': word_dict_seq_list,
        'status': "success",
        'sentences': json_trans_list,
        'speaker_count': len(speaker_map),
        'transcription': None,
    }
|
|
|
def convert_nemo_json_to_seglst(trans_info_dict):
    """
    Collapse each session's word list into seglst segments, one segment per run
    of consecutive words that share the same speaker.

    Args:
        trans_info_dict (dict): Session ID -> dict with a ``'words'`` list of word
            dicts carrying ``'word'`` (str) and ``'speaker'`` keys.

    Returns:
        dict: Session ID -> list of segment dicts with keys ``session_id``,
        ``words`` (the run's words joined by single spaces), ``start_time`` and
        ``end_time`` (always 0.0 placeholders) and ``speaker``. A session without
        words maps to an empty list.
    """
    from itertools import groupby

    seg_lst_dict = {}
    for uniq_id, session_dict in trans_info_dict.items():
        seg_lst_dict[uniq_id] = [
            {
                'session_id': uniq_id,
                'words': ' '.join(word_dict['word'] for word_dict in word_group).strip(),
                'start_time': 0.0,
                'end_time': 0.0,
                'speaker': speaker,
            }
            for speaker, word_group in groupby(session_dict['words'], key=lambda word_dict: word_dict['speaker'])
        ]
    return seg_lst_dict
|
|
|
def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
    """
    Load the seglst JSON files named in a list file and convert each one with
    ``add_placeholder_speaker_softmax``.

    Args:
        input_error_src_list_path (str): Text file with one JSON path per line;
            blank lines are ignored.
        ext_str (str): The session ID is the part of each file name before the
            first occurrence of this string.
        peak_prob (float): Forwarded to ``add_placeholder_speaker_softmax``.
        max_spks (int): Forwarded to ``add_placeholder_speaker_softmax``.

    Returns:
        dict: Session ID -> converted session dict.

    Raises:
        FileNotFoundError: If a listed JSON path does not exist.
    """
    with open(input_error_src_list_path) as list_file:
        json_filepath_list = [line.strip() for line in list_file if line.strip()]

    trans_info_dict = {}
    for json_path in json_filepath_list:
        uniq_id = os.path.basename(json_path).split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
        with open(json_path, "r", encoding="utf-8") as file:
            json_trans = json.load(file)
        trans_info_dict[uniq_id] = add_placeholder_speaker_softmax(json_trans, peak_prob=peak_prob, max_spks=max_spks)
    return trans_info_dict
|
|
|
def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
    """
    Load the reference seglst JSON files named in a list file.

    Every segment dict gets a ``'session_id'`` key set to the session ID derived
    from its file name.

    Args:
        reference_seglst_list_path (str): Text file with one JSON path per line;
            blank lines are ignored.
        ext_str (str): The session ID is the part of each file name before the
            first occurrence of this string.

    Returns:
        dict: Session ID -> list of segment dicts loaded from that file.

    Raises:
        FileNotFoundError: If a listed JSON path does not exist.
    """
    with open(reference_seglst_list_path) as list_file:
        json_filepath_list = [line.strip() for line in list_file if line.strip()]

    reference_info_dict = {}
    for json_path in json_filepath_list:
        uniq_id = os.path.basename(json_path).split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
        with open(json_path, "r", encoding="utf-8") as file:
            json_trans = json.load(file)
        for sentence_dict in json_trans:
            sentence_dict['session_id'] = uniq_id
        reference_info_dict[uniq_id] = json_trans
    return reference_info_dict
|
|
|
def write_seglst_jsons(
    seg_lst_sessions_dict: dict,
    input_error_src_list_path: str,
    diar_out_path: str,
    ext_str: str,
    write_individual_seglst_jsons=True
):
    """
    Write segment-list (seglst) JSON files to the output directory.

    Optionally writes one ``<session_id>.seglst.json`` per session, and always
    writes one combined file containing every session's segments. The combined
    file is named after the input list file: its basename with ``.list`` removed
    and ``src``/``ref`` replaced by ``ext_str``, followed by ``.seglst.json``.

    Parameters:
        seg_lst_sessions_dict (dict): Session ID -> list of segment dicts.
        input_error_src_list_path (str): Path of the input list file; only its
            basename is used, to name the combined output file.
        diar_out_path (str): Output directory (this function does not create it).
        ext_str (str): String that replaces ``src``/``ref`` in the combined file
            name (e.g. 'hyp' for hypothesis or 'ref' for reference).
        write_individual_seglst_jsons (bool, optional): Whether to also write one
            file per session. Defaults to True.

    Returns:
        None
    """
    import re

    total_infer_list = []
    for session_id, seg_lst_list in seg_lst_sessions_dict.items():
        total_infer_list.extend(seg_lst_list)
        if write_individual_seglst_jsons:
            print(f"{__INFO_TAG__} Writing {diar_out_path}/{session_id}.seglst.json")
            with open(f'{diar_out_path}/{session_id}.seglst.json', 'w') as file:
                json.dump(seg_lst_list, file, indent=4)

    total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
    # Single pass, so an ext_str that itself contains "ref" is not rewritten a second time.
    total_output_filename = re.sub("src|ref", lambda _match: ext_str, total_output_filename)
    total_output_path = f'{diar_out_path}/{total_output_filename}.seglst.json'
    print(f"{__INFO_TAG__} Writing {total_output_path}")
    with open(total_output_path, 'w') as file:
        json.dump(total_infer_list, file, indent=4)
|
|