# SLT-Task2-ngram-baseline / beam_search_utils.py
# Uploaded by Taejin in commit 5917f0a (verified): "Uploading ngram base model" (14.5 kB).
from tqdm import tqdm
from typing import Dict, List
from pydiardecode import build_diardecoder
import numpy as np
import copy
import os
import json
import concurrent.futures
import kenlm
# Prefix prepended to the progress messages this module prints to stdout.
__INFO_TAG__ = "[INFO]"
class SpeakerTaggingBeamSearchDecoder:
    """
    Re-assign speaker labels to words with a beam search built by `pydiardecode.build_diardecoder`.

    Besides per-session decoding, the class provides helpers that split long sessions into
    overlapping chunks for parallel decoding (`divide_chunks`) and stitch the decoded chunks
    back into whole sessions (`merge_div_inputs`).
    """

    def __init__(self, loaded_kenlm_model: "kenlm.Model", cfg: dict):
        """
        Args:
            loaded_kenlm_model: language-model object handed to `build_diardecoder`
                (presumably a `kenlm.Model` -- confirm against pydiardecode).
            cfg (dict): must provide 'arpa_language_model', 'alpha', 'beta', 'word_window',
                'use_ngram' (read when building the decoder) and 'beam_width' (read when decoding).
        """
        self.realigning_lm_params = cfg
        self.realigning_lm = self._load_realigning_LM(loaded_kenlm_model=loaded_kenlm_model)
        # Separator for chunk IDs of the form "<uniq_id>@<chunk_idx>@<chunk_count>".
        self._SPLITSYM = "@"

    def _load_realigning_LM(self, loaded_kenlm_model: "kenlm.Model"):
        """
        Build the decoder used to realign speaker labels for words.

        Args:
            loaded_kenlm_model: forwarded to `build_diardecoder` together with the
                ARPA model path, `alpha`, `beta`, `word_window` and `use_ngram` from the config.
        Returns:
            The decoder object returned by `build_diardecoder`.
        """
        diar_decoder = build_diardecoder(
            loaded_kenlm_model=loaded_kenlm_model,
            kenlm_model_path=self.realigning_lm_params['arpa_language_model'],
            alpha=self.realigning_lm_params['alpha'],
            beta=self.realigning_lm_params['beta'],
            word_window=self.realigning_lm_params['word_window'],
            use_ngram=self.realigning_lm_params['use_ngram'],
        )
        return diar_decoder

    def realign_words_with_lm(self, word_dict_seq_list: List[Dict[str, float]], speaker_count: int = None, port_num=None) -> List[Dict[str, float]]:
        """
        Decode one word sequence with the realigning decoder.

        Args:
            word_dict_seq_list: word dicts; each has at least a 'speaker' entry.
            speaker_count (int, optional): if given, the candidate speakers are
                "speaker_0" ... "speaker_{speaker_count-1}"; otherwise they are the distinct
                'speaker' values found in `word_dict_seq_list`.
            port_num: forwarded unchanged to `decode_beams`.
        Returns:
            Whatever `decode_beams` returns; `beam_search_diarization` reads `result[0][2]`.
        """
        if speaker_count is None:
            candidates = {word_dict['speaker'] for word_dict in word_dict_seq_list}
        else:
            candidates = {f"speaker_{k}" for k in range(speaker_count)}
        return self.realigning_lm.decode_beams(
            beam_width=self.realigning_lm_params['beam_width'],
            speaker_list=sorted(candidates),
            word_dict_seq_list=word_dict_seq_list,
            port_num=port_num,
        )

    def beam_search_diarization(
        self,
        trans_info_dict: Dict[str, Dict[str, list]],
        port_num=None,
    ) -> Dict[str, Dict[str, float]]:
        """
        Realign speaker labels for every session in `trans_info_dict`, in place.

        Each session's 'words' and 'speaker_count' are passed to `realign_words_with_lm`, and
        'words' is replaced by element [2] of the first returned beam (presumably the
        best-scoring beam -- confirm against pydiardecode's `decode_beams`).

        Args:
            trans_info_dict (dict): session ID -> dict with 'words' and 'speaker_count'.
            port_num: a single port value forwarded to `realign_words_with_lm`.
        Returns:
            trans_info_dict (dict): the same dict object, with every session's 'words' updated.
        """
        for session_dict in trans_info_dict.values():
            output_beams = self.realign_words_with_lm(
                word_dict_seq_list=session_dict['words'],
                speaker_count=session_dict['speaker_count'],
                port_num=port_num,
            )
            session_dict['words'] = output_beams[0][2]
        return trans_info_dict

    def merge_div_inputs(self, div_trans_info_dict, org_trans_info_dict, win_len=250, word_window=16):
        """
        Stitch chunk-level results produced from `divide_chunks` back into whole sessions.

        Every chunk after the first starts with up to `word_window` words of left context that
        the previous chunk already covers; those words are dropped before concatenation. The real
        overlap of chunk k is `min(word_window, k * stride)`, where the stride is the length of
        chunk 0. (Chunk 0 holds exactly one stride whenever a session has two or more chunks.)
        Like the trimming itself, this assumes decoding preserves each chunk's word count.

        Args:
            div_trans_info_dict (dict): chunk ID "<uniq_id>@<idx>@<count>" -> dict with 'words'.
            org_trans_info_dict (dict): session ID -> session dict; its 'words' are overwritten.
            win_len (int or None): chunk 0 is truncated to its first `win_len` words.
            word_window (int): left-context length that was used by `divide_chunks`.
        Returns:
            org_trans_info_dict (dict): the same dict object with merged 'words'. Sessions that
                have no chunks (e.g. empty sessions) get an empty word list.
        """
        sub_div_dict = {}
        for seq_id, chunk_dict in div_trans_info_dict.items():
            # rsplit keeps session IDs that themselves contain the separator intact.
            uniq_id, sub_idx, total_count = seq_id.rsplit(self._SPLITSYM, 2)
            chunks = sub_div_dict.setdefault(uniq_id, [None] * int(total_count))
            chunks[int(sub_idx)] = chunk_dict['words']

        for uniq_id, session_dict in org_trans_info_dict.items():
            chunks = sub_div_dict.get(uniq_id, [])
            merged_words = []
            for k, div_words in enumerate(chunks):
                if k == 0:
                    merged_words.extend(div_words[:win_len])
                else:
                    # Early chunks start at index 0, so their overlap is shorter than word_window.
                    overlap = min(word_window, k * len(chunks[0]))
                    merged_words.extend(div_words[overlap:])
            session_dict['words'] = merged_words
        return org_trans_info_dict

    def divide_chunks(self, trans_info_dict, win_len, word_window, port):
        """
        Split each session's word sequence into overlapping chunks for parallel decoding.

        Chunk k holds the words [k * win_len, (k + 1) * win_len), prefixed by up to
        `word_window` preceding words as left context. Chunk IDs have the form
        "<uniq_id>@<chunk_idx>@<chunk_count>"; `merge_div_inputs` reverses the split.

        Side effect: 'status', 'transcription' and 'sentences' are removed from every session
        dict in `trans_info_dict`, if present.

        Args:
            trans_info_dict (dict): session ID -> session dict with a 'words' list.
            win_len (int or None): new words per chunk. If None, it is chosen per session so
                the session splits into about `max(len(port), 1)` chunks.
            word_window (int): number of left-context words prepended to chunks after the first.
            port (sequence or None): only its length is used, to size chunks when `win_len` is None.
        Returns:
            dict: chunk ID -> shallow copy of the session dict with 'words' set to the chunk.
                Sessions without words produce no chunks.
        """
        num_workers = max(len(port) if port is not None else 0, 1)
        div_trans_info_dict = {}
        for uniq_id, uniq_trans in trans_info_dict.items():
            for unused_key in ('status', 'transcription', 'sentences'):
                uniq_trans.pop(unused_key, None)
            word_seq = uniq_trans['words']
            # Resolve the window per session; reassigning `win_len` itself would leak the
            # first session's value into every later session.
            session_win_len = win_len
            if session_win_len is None:
                session_win_len = max(int(np.ceil(len(word_seq) / num_workers)), 1)
            n_chunks = int(np.ceil(len(word_seq) / session_win_len))
            for k in range(n_chunks):
                start = max(k * session_win_len - word_window, 0)
                seq_id = f"{uniq_id}{self._SPLITSYM}{k}{self._SPLITSYM}{n_chunks}"
                chunk_dict = dict(uniq_trans)
                chunk_dict['words'] = word_seq[start:(k + 1) * session_win_len]
                div_trans_info_dict[seq_id] = chunk_dict
        return div_trans_info_dict
def run_mp_beam_search_decoding(
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    trans_info_dict,
    org_trans_info_dict,
    div_mp,
    win_len,
    word_window,
    port=None,
    use_ngram=False,
    ngram_num_workers=36,
):
    """
    Run `beam_search_diarization` for every session in a process pool and collect the results.

    Args:
        speaker_beam_search_decoder: object whose `beam_search_diarization` is submitted once
            per session; it is sent to worker processes, so it must be picklable.
        loaded_kenlm_model: not used by this function; kept for interface compatibility.
        trans_info_dict (dict): session (or chunk) ID -> session dict to decode.
        org_trans_info_dict (dict): original sessions; used only when `div_mp` is True, as the
            target of `merge_div_inputs`.
        div_mp (bool): True when `trans_info_dict` holds chunks made by `divide_chunks`.
        win_len, word_window: forwarded to `merge_div_inputs` when `div_mp` is True.
        port (sequence or None): ports assigned round-robin to sessions. With more than one
            entry, every entry is converted with `int()`. Ignored when `use_ngram` is True.
            When None or empty (and `use_ngram` is False), a single worker is used with no port.
        use_ngram (bool): if True, decode without ports using `ngram_num_workers` processes.
        ngram_num_workers (int): pool size used when `use_ngram` is True.
    Returns:
        dict: decoded sessions keyed by ID, inserted in sorted-ID order. When `div_mp` is True,
            this is `org_trans_info_dict` with the chunks merged back into whole sessions.
    """
    if port is not None and len(port) > 1:
        port = [int(p) for p in port]
    if use_ngram:
        port = [None]
        num_workers = ngram_num_workers
    elif port is not None and len(port) > 0:
        num_workers = len(port)
    else:
        # No ports: previously this crashed on len(None) or max_workers=0.
        port = [None]
        num_workers = 1

    uniq_id_list = sorted(trans_info_dict)
    futures = []
    # The context manager shuts the pool down even if a worker raises.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        for idx, uniq_id in enumerate(uniq_id_list):
            print(f"{__INFO_TAG__} Running beam search decoding for {uniq_id}...")
            futures.append(
                executor.submit(
                    speaker_beam_search_decoder.beam_search_diarization,
                    {uniq_id: trans_info_dict[uniq_id]},
                    port_num=port[idx % len(port)],
                )
            )
        with tqdm(total=len(uniq_id_list), desc="Running beam search decoding", unit="files") as pbar:
            for done_future in concurrent.futures.as_completed(futures):
                done_future.result()  # surface worker exceptions as soon as they complete
                pbar.update()

    # Merge in submission (sorted-ID) order rather than completion order so output is deterministic.
    output_trans_info_dict = {}
    for future in futures:
        output_trans_info_dict.update(future.result())

    if div_mp:
        output_trans_info_dict = speaker_beam_search_decoder.merge_div_inputs(
            div_trans_info_dict=output_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            win_len=win_len,
            word_window=word_window,
        )
    return output_trans_info_dict
def count_num_of_spks(json_trans_list):
    """
    Map each distinct speaker label to a contiguous integer index.

    Indices follow the order in which speakers first appear in `json_trans_list`. The mapping
    is therefore reproducible across runs, unlike enumerating a `set` of strings, whose order
    changes with per-process hash randomization.

    Args:
        json_trans_list (list of dict): segments, each with a 'speaker' entry.
    Returns:
        dict: speaker label -> index in [0, number of distinct speakers).
    """
    # dict.fromkeys de-duplicates while preserving first-appearance order.
    unique_speakers = dict.fromkeys(segment['speaker'] for segment in json_trans_list)
    return {spk_str: idx for idx, spk_str in enumerate(unique_speakers)}
def add_placeholder_speaker_softmax(json_trans_list, peak_prob=0.94, max_spks=4):
    """
    Expand seglst segments into the word-level session dict consumed by the beam-search decoder.

    Every whitespace-separated word of every segment becomes one word entry. Its
    'speaker_softmax' is a length-`max_spks` vector with `peak_prob` at the index of the
    segment's speaker. The remaining mass `1 - peak_prob` is split evenly over the other
    `max_spks - 1` entries. Word timestamps are not available here and are set to None.

    Args:
        json_trans_list (list of dict): segments, each with 'words' (str) and 'speaker'.
        peak_prob (float): probability placed on the segment's own speaker; must be in [0, 1].
        max_spks (int): length of each speaker_softmax vector; must be at least the number of
            distinct speakers.
    Returns:
        dict: {'words': [...], 'status': "success", 'sentences': json_trans_list,
               'speaker_count': <number of distinct speakers>, 'transcription': None}
    Raises:
        ValueError: if `peak_prob` is outside [0, 1], or if there are more distinct speakers
            than `max_spks` (previously an opaque IndexError).
    """
    if peak_prob > 1 or peak_prob < 0:
        raise ValueError(f"peak_prob must be between 0 and 1 but got {peak_prob}")
    speaker_map = count_num_of_spks(json_trans_list)
    if len(speaker_map) > max_spks:
        raise ValueError(
            f"Found {len(speaker_map)} distinct speakers but max_spks is {max_spks}"
        )
    # With a single slot there is no off-peak mass to spread; the old formula divided by zero.
    off_peak = (1 - peak_prob) / (max_spks - 1) if max_spks > 1 else 0.0
    base_array = np.full(max_spks, off_peak)

    word_dict_seq_list = []
    for sentence_dict in json_trans_list:
        spk_idx = speaker_map[sentence_dict['speaker']]
        for word in sentence_dict['words'].split():
            speaker_softmax = base_array.copy()
            speaker_softmax[spk_idx] = peak_prob
            word_dict_seq_list.append({
                'word': word,
                'start_time': None,
                'end_time': None,
                'speaker': spk_idx,
                'speaker_softmax': speaker_softmax,
            })
    return {
        'words': word_dict_seq_list,
        'status': "success",
        'sentences': json_trans_list,
        'speaker_count': len(speaker_map),
        'transcription': None,
    }
def convert_nemo_json_to_seglst(trans_info_dict):
    """
    Convert word-level decoder output into per-session seglst segment lists.

    Consecutive words that share the same 'speaker' are joined with single spaces into one
    segment. Timestamps are not tracked here, so 'start_time' and 'end_time' are always 0.0.

    Args:
        trans_info_dict (dict): session ID -> dict whose 'words' is a list of dicts with
            'word' and 'speaker'.
    Returns:
        dict: session ID -> list of segment dicts with keys 'session_id', 'words',
            'start_time', 'end_time', 'speaker'. A session without words maps to [].
    """
    def _segment(session_id, run_words, speaker):
        # One seglst entry for a run of consecutive same-speaker words.
        return {
            'session_id': session_id,
            'words': " ".join(run_words).strip(),
            'start_time': 0.0,
            'end_time': 0.0,
            'speaker': speaker,
        }

    seg_lst_dict = {}
    for uniq_id, session_dict in trans_info_dict.items():
        seglst_seq_list = []
        run_words, run_speaker = [], None
        for word_dict in session_dict['words']:
            curr_speaker = word_dict['speaker']
            # A speaker change closes the current run.
            if run_words and curr_speaker != run_speaker:
                seglst_seq_list.append(_segment(uniq_id, run_words, run_speaker))
                run_words = []
            run_words.append(word_dict['word'])
            run_speaker = curr_speaker
        if run_words:
            seglst_seq_list.append(_segment(uniq_id, run_words, run_speaker))
        seg_lst_dict[uniq_id] = seglst_seq_list
    return seg_lst_dict
def load_input_jsons(input_error_src_list_path, ext_str=".seglst.json", peak_prob=0.94, max_spks=4):
    """
    Load every seglst JSON listed in a list file and convert it for the beam-search decoder.

    Args:
        input_error_src_list_path (str): text file with one JSON path per line; blank lines are skipped.
        ext_str (str): the session ID is the file name up to the first occurrence of `ext_str`.
        peak_prob (float): forwarded to `add_placeholder_speaker_softmax`.
        max_spks (int): forwarded to `add_placeholder_speaker_softmax`.
    Returns:
        dict: session ID -> word-level session dict from `add_placeholder_speaker_softmax`.
    Raises:
        FileNotFoundError: if a listed JSON path does not exist.
    """
    # Read the list with a context manager so the handle is closed, and ignore blank lines
    # (previously they became the path "" and raised FileNotFoundError).
    with open(input_error_src_list_path, encoding="utf-8") as list_file:
        json_filepath_list = [line.strip() for line in list_file if line.strip()]

    trans_info_dict = {}
    for json_path in json_filepath_list:
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        with open(json_path, "r", encoding="utf-8") as file:
            json_trans = json.load(file)
        trans_info_dict[uniq_id] = add_placeholder_speaker_softmax(
            json_trans, peak_prob=peak_prob, max_spks=max_spks
        )
    return trans_info_dict
def load_reference_jsons(reference_seglst_list_path, ext_str=".seglst.json"):
    """
    Load every reference seglst JSON listed in a list file, tagging segments with their session ID.

    Args:
        reference_seglst_list_path (str): text file with one JSON path per line; blank lines are skipped.
        ext_str (str): the session ID is the file name up to the first occurrence of `ext_str`.
    Returns:
        dict: session ID -> list of segment dicts, each with 'session_id' set (in place) to that ID.
    Raises:
        FileNotFoundError: if a listed JSON path does not exist.
    """
    # Close the list file deterministically and skip blank lines instead of treating them as paths.
    with open(reference_seglst_list_path, encoding="utf-8") as list_file:
        json_filepath_list = [line.strip() for line in list_file if line.strip()]

    reference_info_dict = {}
    for json_path in json_filepath_list:
        uniq_id = os.path.split(json_path)[-1].split(ext_str)[0]
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"{json_path} does not exist. Aborting.")
        with open(json_path, "r", encoding="utf-8") as file:
            json_trans = json.load(file)
        for sentence_dict in json_trans:
            sentence_dict['session_id'] = uniq_id
        reference_info_dict[uniq_id] = list(json_trans)
    return reference_info_dict
def write_seglst_jsons(
    seg_lst_sessions_dict: dict,
    input_error_src_list_path: str,
    diar_out_path: str,
    ext_str: str,
    write_individual_seglst_jsons=True
):
    """
    Write seglst JSON files: optionally one per session, plus one combined file.

    Parameters:
        seg_lst_sessions_dict (dict): session ID -> list of segment dicts.
        input_error_src_list_path (str): path of the input list file; its base name, with ".list"
            removed and "src"/"ref" replaced by `ext_str`, names the combined output file.
        diar_out_path (str): directory for the per-session files. The combined file is written
            to its parent directory (`{diar_out_path}/../<name>.seglst.json`).
        ext_str (str): replacement for "src" and "ref" in the combined output file name.
        write_individual_seglst_jsons (bool, optional): also write
            `{diar_out_path}/{session_id}.seglst.json` for each session. Defaults to True.
    Returns:
        None
    """
    total_infer_list = []
    for session_id, seg_lst_list in seg_lst_sessions_dict.items():
        total_infer_list.extend(seg_lst_list)
        if write_individual_seglst_jsons:
            session_out_path = f'{diar_out_path}/{session_id}.seglst.json'
            print(f"{__INFO_TAG__} Writing {session_out_path}")
            with open(session_out_path, 'w') as file:
                json.dump(seg_lst_list, file, indent=4)  # indent=4 for pretty printing

    total_output_filename = os.path.split(input_error_src_list_path)[-1].replace(".list", "")
    total_output_filename = total_output_filename.replace("src", ext_str).replace("ref", ext_str)
    total_out_path = f'{diar_out_path}/../{total_output_filename}.seglst.json'
    # Report the combined file actually being written. Previously this printed the last
    # session's path and raised NameError when there were no sessions.
    print(f"{__INFO_TAG__} Writing {total_out_path}")
    with open(total_out_path, 'w') as file:
        json.dump(total_infer_list, file, indent=4)  # indent=4 for pretty printing