johann22's picture
Upload 13 files
84d670e verified
import json
import logging
import os
import shutil
import nltk
import wget
from omegaconf import OmegaConf
from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE
# Two-letter language codes for which punctuation restoration is expected to be
# available (inferred from the name; confirm against the code that consumes it).
punct_model_langs = [
    "en", "fr", "de", "es", "it", "nl",
    "pt", "bg", "pl", "cs", "sk", "sl",
]
# Keys of whisperx's DEFAULT_ALIGN_MODELS_TORCH followed by the keys of
# DEFAULT_ALIGN_MODELS_HF, in that order (no de-duplication is performed).
wav2vec2_langs = [
    *DEFAULT_ALIGN_MODELS_TORCH.keys(),
    *DEFAULT_ALIGN_MODELS_HF.keys(),
]
# Accepted language identifiers: the sorted keys of LANGUAGES first, then the
# sorted, title-cased keys of TO_LANGUAGE_CODE.
whisper_langs = [
    *sorted(LANGUAGES.keys()),
    *sorted(name.title() for name in TO_LANGUAGE_CODE.keys()),
]
# Lookup table from two-letter language codes to three-letter codes.
# NOTE(review): the values look like ISO 639-2 bibliographic ("B") codes
# (e.g. "de" -> "ger", "fr" -> "fre", "zh" -> "chi") rather than the
# terminologic ones -- confirm against whatever consumes this table.
langs_to_iso = {
    "aa": "aar",
    "ab": "abk",
    "ae": "ave",
    "af": "afr",
    "ak": "aka",
    "am": "amh",
    "an": "arg",
    "ar": "ara",
    "as": "asm",
    "av": "ava",
    "ay": "aym",
    "az": "aze",
    "ba": "bak",
    "be": "bel",
    "bg": "bul",
    "bh": "bih",
    "bi": "bis",
    "bm": "bam",
    "bn": "ben",
    "bo": "tib",
    "br": "bre",
    "bs": "bos",
    "ca": "cat",
    "ce": "che",
    "ch": "cha",
    "co": "cos",
    "cr": "cre",
    "cs": "cze",
    "cu": "chu",
    "cv": "chv",
    "cy": "wel",
    "da": "dan",
    "de": "ger",
    "dv": "div",
    "dz": "dzo",
    "ee": "ewe",
    "el": "gre",
    "en": "eng",
    "eo": "epo",
    "es": "spa",
    "et": "est",
    "eu": "baq",
    "fa": "per",
    "ff": "ful",
    "fi": "fin",
    "fj": "fij",
    "fo": "fao",
    "fr": "fre",
    "fy": "fry",
    "ga": "gle",
    "gd": "gla",
    "gl": "glg",
    "gn": "grn",
    "gu": "guj",
    "gv": "glv",
    "ha": "hau",
    "he": "heb",
    "hi": "hin",
    "ho": "hmo",
    "hr": "hrv",
    "ht": "hat",
    "hu": "hun",
    "hy": "arm",
    "hz": "her",
    "ia": "ina",
    "id": "ind",
    "ie": "ile",
    "ig": "ibo",
    "ii": "iii",
    "ik": "ipk",
    "io": "ido",
    "is": "ice",
    "it": "ita",
    "iu": "iku",
    "ja": "jpn",
    "jv": "jav",
    "ka": "geo",
    "kg": "kon",
    "ki": "kik",
    "kj": "kua",
    "kk": "kaz",
    "kl": "kal",
    "km": "khm",
    "kn": "kan",
    "ko": "kor",
    "kr": "kau",
    "ks": "kas",
    "ku": "kur",
    "kv": "kom",
    "kw": "cor",
    "ky": "kir",
    "la": "lat",
    "lb": "ltz",
    "lg": "lug",
    "li": "lim",
    "ln": "lin",
    "lo": "lao",
    "lt": "lit",
    "lu": "lub",
    "lv": "lav",
    "mg": "mlg",
    "mh": "mah",
    "mi": "mao",
    "mk": "mac",
    "ml": "mal",
    "mn": "mon",
    "mr": "mar",
    "ms": "may",
    "mt": "mlt",
    "my": "bur",
    "na": "nau",
    "nb": "nob",
    "nd": "nde",
    "ne": "nep",
    "ng": "ndo",
    "nl": "dut",
    "nn": "nno",
    "no": "nor",
    "nr": "nbl",
    "nv": "nav",
    "ny": "nya",
    "oc": "oci",
    "oj": "oji",
    "om": "orm",
    "or": "ori",
    "os": "oss",
    "pa": "pan",
    "pi": "pli",
    "pl": "pol",
    "ps": "pus",
    "pt": "por",
    "qu": "que",
    "rm": "roh",
    "rn": "run",
    "ro": "rum",
    "ru": "rus",
    "rw": "kin",
    "sa": "san",
    "sc": "srd",
    "sd": "snd",
    "se": "sme",
    "sg": "sag",
    "si": "sin",
    "sk": "slo",
    "sl": "slv",
    "sm": "smo",
    "sn": "sna",
    "so": "som",
    "sq": "alb",
    "sr": "srp",
    "ss": "ssw",
    "st": "sot",
    "su": "sun",
    "sv": "swe",
    "sw": "swa",
    "ta": "tam",
    "te": "tel",
    "tg": "tgk",
    "th": "tha",
    "ti": "tir",
    "tk": "tuk",
    "tl": "tgl",
    "tn": "tsn",
    "to": "ton",
    "tr": "tur",
    "ts": "tso",
    "tt": "tat",
    "tw": "twi",
    "ty": "tah",
    "ug": "uig",
    "uk": "ukr",
    "ur": "urd",
    "uz": "uzb",
    "ve": "ven",
    "vi": "vie",
    "vo": "vol",
    "wa": "wln",
    "wo": "wol",
    "xh": "xho",
    "yi": "yid",
    "yo": "yor",
    "za": "zha",
    "zh": "chi",
    "zu": "zul",
}
def create_config(output_dir):
    """Prepare the NeMo diarization inference config for one audio file.

    The telephonic inference YAML is looked up under ``nemo_msdd_configs`` in
    the current working directory and downloaded from the NeMo GitHub
    repository when it is not there yet. A one-line manifest describing
    ``<output_dir>/mono_file.wav`` is written to
    ``<output_dir>/data/input_manifest.json`` and the loaded config is pointed
    at it.

    Args:
        output_dir: Directory that holds ``mono_file.wav``; it also receives
            the ``data`` sub-directory and is set as the diarizer's ``out_dir``.

    Returns:
        The OmegaConf config with the manifest path, output directory, VAD,
        speaker-embedding and MSDD model settings filled in.
    """
    domain_type = "telephonic"  # Can be meeting, telephonic, or general based on domain type of the audio file
    config_dir = "nemo_msdd_configs"
    config_name = f"diar_infer_{domain_type}.yaml"
    config_path = os.path.join(config_dir, config_name)

    # Download the stock inference YAML only when no local copy exists.
    if not os.path.exists(config_path):
        os.makedirs(config_dir, exist_ok=True)
        config_url = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{config_name}"
        # wget.download returns the name of the file it actually wrote.
        config_path = wget.download(config_url, config_path)

    config = OmegaConf.load(config_path)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    manifest_path = os.path.join(data_dir, "input_manifest.json")

    manifest_entry = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    # The manifest is a single JSON object followed by a newline.
    with open(manifest_path, "w") as manifest_file:
        json.dump(manifest_entry, manifest_file)
        manifest_file.write("\n")

    config.num_workers = 0

    diarizer = config.diarizer
    diarizer.manifest_filepath = manifest_path
    # Directory to store intermediate files and prediction outputs
    diarizer.out_dir = output_dir
    diarizer.speaker_embeddings.model_path = "titanet_large"
    # compute VAD provided with model_path to vad config
    diarizer.oracle_vad = False
    diarizer.clustering.parameters.oracle_num_speakers = False

    # Here, we use our in-house pretrained NeMo VAD model
    diarizer.vad.model_path = "vad_multilingual_marblenet"
    diarizer.vad.parameters.onset = 0.8
    diarizer.vad.parameters.offset = 0.6
    diarizer.vad.parameters.pad_offset = -0.05

    # Telephonic speaker diarization model
    diarizer.msdd_model.model_path = "diar_msdd_telephonic"

    return config
def get_word_ts_anchor(s, e, option="start"):
    """Pick the single timestamp that stands for a word spanning ``s`` to ``e``.

    Args:
        s: Word start time.
        e: Word end time.
        option: ``"end"`` selects ``e``, ``"mid"`` selects the midpoint
            ``(s + e) / 2``; any other value (including the default
            ``"start"``) selects ``s``.

    Returns:
        The selected anchor time.
    """
    if option == "mid":
        return (s + e) / 2
    return e if option == "end" else s
def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    """Label every word with the speaker of the turn its anchor falls into.

    Args:
        wrd_ts: Iterable of dicts with ``"start"``, ``"end"`` and ``"text"``.
            Start and end are multiplied by 1000 and truncated to ``int`` here.
        spk_ts: Non-empty sequence of ``(start, end, speaker)`` turns. Turn
            ends are compared against the scaled word anchors, so they are
            presumably expressed in milliseconds -- confirm with the caller.
        word_anchor_option: Passed to ``get_word_ts_anchor`` to choose which
            point of the word is compared against the turn end.

    Returns:
        A list of dicts with keys ``"word"``, ``"start_time"``, ``"end_time"``
        and ``"speaker"``, one per input word and in input order.
    """
    last_turn = len(spk_ts) - 1
    turn = 0
    _, turn_end, speaker = spk_ts[turn]

    mapping = []
    for entry in wrd_ts:
        start_ms = int(entry["start"] * 1000)
        end_ms = int(entry["end"] * 1000)
        word = entry["text"]
        anchor = get_word_ts_anchor(start_ms, end_ms, word_anchor_option)

        # Move forward through the turns until one ends at or after the anchor.
        while anchor > float(turn_end):
            turn = min(turn + 1, last_turn)
            _, turn_end, speaker = spk_ts[turn]
            if turn == last_turn:
                # Out of turns: stretch the final turn to the word's end so the
                # loop terminates and the word is given to the last speaker.
                turn_end = get_word_ts_anchor(start_ms, end_ms, option="end")

        mapping.append(
            {
                "word": word,
                "start_time": start_ms,
                "end_time": end_ms,
                "speaker": speaker,
            }
        )
    return mapping
# Characters that mark a word as the end of a sentence when they are the
# word's last character.
sentence_ending_punctuations = ".?!"
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    """Walk left from ``word_idx`` to the first word of its sentence.

    The walk continues while the preceding word has the same speaker label,
    does not end with a sentence-ending punctuation mark, and fewer than
    ``max_words`` steps have been taken.

    Args:
        word_idx: Index of the word to start from.
        word_list: Words, indexable by position.
        speaker_list: Speaker labels parallel to ``word_list``.
        max_words: Maximum number of steps to take to the left.

    Returns:
        The index where the walk stopped if that is index 0 or directly
        follows a sentence-ending word; otherwise ``-1``.
    """

    def ends_sentence(idx):
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    first = word_idx
    while first > 0 and word_idx - first < max_words:
        previous = first - 1
        if speaker_list[previous] != speaker_list[first] or ends_sentence(previous):
            break
        first = previous

    if first == 0 or ends_sentence(first - 1):
        return first
    return -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    """Walk right from ``word_idx`` to the last word of its sentence.

    The walk continues while the current word does not end with a
    sentence-ending punctuation mark, the end of ``word_list`` has not been
    reached, and fewer than ``max_words`` steps have been taken.

    Args:
        word_idx: Index of the word to start from.
        word_list: Words, indexable by position.
        max_words: Maximum number of steps to take to the right.

    Returns:
        The index where the walk stopped if that is the final word or a
        sentence-ending word; otherwise ``-1``.
    """

    def ends_sentence(idx):
        return idx >= 0 and word_list[idx][-1] in sentence_ending_punctuations

    final_idx = len(word_list) - 1
    last = word_idx
    while last < final_idx and last - word_idx < max_words:
        if ends_sentence(last):
            break
        last += 1

    if last == final_idx or ends_sentence(last):
        return last
    return -1
def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    """Re-assign speaker labels so a speaker change does not split a sentence.

    Wherever the speaker label changes between a word and the next one and
    that word does not end with a sentence-ending punctuation mark, the
    boundaries of the surrounding sentence are looked up and every word in it
    is given the most frequent speaker label found inside those boundaries.

    Args:
        word_speaker_mapping: Sequence of dicts that have at least the keys
            ``"word"`` and ``"speaker"``.
        max_words_in_sentence: Upper bound handed to the sentence-boundary
            helpers; a sentence whose boundaries are not found within it is
            left unchanged.

    Returns:
        A new list of shallow copies of the input dicts with ``"speaker"``
        replaced by the realigned label. The input dicts are not modified.
    """
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    # Split the mapping into parallel lists; only speaker_list is modified.
    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]  # not used further in this loop
        # Only act on a speaker change that happens in the middle of a sentence.
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            # The words already consumed to the left of k reduce the budget
            # available for the search to the right.
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            # Either boundary missing: leave this sentence as it is.
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            # Most frequent label in the sentence.
            # NOTE(review): on a tie, max() returns the first maximal element
            # in the set's iteration order, which is not a defined order.
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            # Skip when the most frequent label covers fewer than half
            # (integer division) of the words in the sentence.
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            # Resume scanning after the sentence that was just rewritten.
            k = right_idx

        k += 1

    # Build the result from copies so the caller's dicts stay untouched.
    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list
def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    """Group speaker-labelled words into sentence segments.

    A new segment is started whenever the speaker label changes or NLTK's
    Punkt tokenizer reports a sentence break in the current segment text plus
    the next word.

    Args:
        word_speaker_mapping: Iterable of dicts with the keys ``"word"``,
            ``"speaker"``, ``"start_time"`` and ``"end_time"``.
        spk_ts: Non-empty sequence of ``(start, end, speaker)`` turns; only the
            first turn is read, to seed the first segment.

    Returns:
        A list of dicts with the keys ``"speaker"`` (formatted as
        ``"Speaker <label>"``), ``"start_time"``, ``"end_time"`` and
        ``"text"``. Each word is appended to the text followed by one space.
    """
    contains_sentbreak = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak

    def new_sentence(speaker, start, end):
        return {
            "speaker": f"Speaker {speaker}",
            "start_time": start,
            "end_time": end,
            "text": "",
        }

    first_start, first_end, last_speaker = spk_ts[0]
    current = new_sentence(last_speaker, first_start, first_end)
    sentences = []

    for entry in word_speaker_mapping:
        word, speaker = entry["word"], entry["speaker"]
        start, end = entry["start_time"], entry["end_time"]

        if speaker == last_speaker and not contains_sentbreak(
            current["text"] + " " + word
        ):
            # Same speaker, same sentence: just extend the running segment.
            current["end_time"] = end
        else:
            sentences.append(current)
            current = new_sentence(speaker, start, end)

        current["text"] += word + " "
        last_speaker = speaker

    sentences.append(current)
    return sentences
def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    """Write the sentences to ``f`` as one paragraph per speaker turn.

    Each paragraph starts with ``"<speaker>: "`` and consecutive sentences of
    the same speaker are appended to it, each followed by a single space.
    A change of speaker starts a new paragraph after a blank line.

    Args:
        sentences_speaker_mapping: Sequence of dicts with the keys
            ``"speaker"`` and ``"text"``. An empty sequence writes nothing.
        f: Writable text stream (anything with a ``write`` method).
    """
    # Nothing to transcribe: avoid indexing into an empty sequence.
    if not sentences_speaker_mapping:
        return

    previous_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{previous_speaker}: ")

    for sentence_dict in sentences_speaker_mapping:
        speaker = sentence_dict["speaker"]
        sentence = sentence_dict["text"]

        # If this speaker doesn't match the previous one, start a new paragraph
        if speaker != previous_speaker:
            f.write(f"\n\n{speaker}: ")
            previous_speaker = speaker

        # No matter what, write the current sentence
        f.write(sentence + " ")
def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    """Format a millisecond count as ``[HH:]MM:SS<marker>mmm``.

    Args:
        milliseconds: Non-negative duration in milliseconds. Fractional values
            are rounded to the nearest whole millisecond, because the integer
            format codes used below reject ``float`` values.
        always_include_hours: Emit the ``HH:`` prefix even when it is ``00``.
        decimal_marker: Separator placed between seconds and milliseconds.

    Returns:
        The formatted timestamp string.

    Raises:
        AssertionError: If ``milliseconds`` is negative.
    """
    assert milliseconds >= 0, "non-negative timestamp expected"

    # Work on an int from here on so the ``d`` format codes always apply.
    total_ms = int(round(milliseconds))

    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, millis = divmod(remainder, 1_000)

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{millis:03d}"
def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.

    Each segment becomes one numbered cue (counting from 1) with a
    ``start --> end`` line using ``,`` as the decimal marker, followed by
    ``"<speaker>: <text>"`` and a blank line. Any ``-->`` inside the text is
    replaced by ``->``. The stream is flushed after every cue.

    Args:
        transcript: Iterable of dicts with the keys ``"start_time"``,
            ``"end_time"``, ``"speaker"`` and ``"text"``.
        file: Writable text stream the cues are printed to.
    """
    for cue_number, segment in enumerate(transcript, start=1):
        begin = format_timestamp(
            segment['start_time'], always_include_hours=True, decimal_marker=','
        )
        finish = format_timestamp(
            segment['end_time'], always_include_hours=True, decimal_marker=','
        )
        speaker = segment['speaker']
        caption = segment['text'].strip().replace('-->', '->')

        # print() adds the final newline that yields the blank separator line.
        cue = f"{cue_number}\n{begin} --> {finish}\n{speaker}: {caption}\n"
        print(cue, file=file, flush=True)
def find_numeral_symbol_tokens(tokenizer):
    """Collect the ids of vocabulary tokens that contain a digit or ``%$£``.

    Args:
        tokenizer: Object whose ``get_vocab()`` returns a mapping from token
            string to token id.

    Returns:
        A list that starts with ``-1`` and is followed by the ids of all
        matching tokens, in the vocabulary's iteration order.
    """
    flagged_chars = "0123456789%$£"
    matching_ids = [
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if any(char in flagged_chars for char in token)
    ]
    return [-1] + matching_ids
def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
    """Return the start time of the next word that has one.

    Following words without a ``"start"`` are merged into the current word:
    their text is appended to the current word's ``"word"`` after a space and
    their own ``"word"`` is set to ``None``. ``word_timestamps`` is modified
    in place.

    Args:
        word_timestamps: List of word dicts with a ``"word"`` key and an
            optional ``"start"`` key.
        current_word_index: Index of the word whose successor is wanted.
        final_timestamp: Value returned when every following word lacks a
            start time.

    Returns:
        The current word's own ``"start"`` when it is the last word, the
        ``"start"`` of the first following word that has one, or
        ``final_timestamp`` when none does. ``None`` is returned for an index
        past the end of the list.
    """
    word_count = len(word_timestamps)
    last_index = word_count - 1

    # The last word has no successor, so its own start is returned.
    if current_word_index == last_index:
        return word_timestamps[current_word_index]["start"]
    if current_word_index > last_index:
        return None

    for candidate_index in range(current_word_index + 1, word_count):
        candidate = word_timestamps[candidate_index]
        if candidate.get("start") is not None:
            return candidate["start"]

        # No start on the candidate: fold its text into the current word and
        # mark it as deleted.
        word_timestamps[current_word_index]["word"] += " " + candidate["word"]
        candidate["word"] = None

    return final_timestamp
def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    """Fill in missing word timestamps and drop words that were merged away.

    A word without a ``"start"`` gets the previous word's ``"end"`` as its
    start (``initial_timestamp`` for the first word) and the value returned by
    ``_get_next_start_timestamp`` as its end. That helper may merge following
    words that have no start into the current one and set their ``"word"`` to
    ``None``; such entries are left out of the result. The input dicts are
    modified in place.

    Args:
        word_timestamps: List of word dicts with a ``"word"`` key and optional
            ``"start"`` / ``"end"`` keys.
        initial_timestamp: Start used for the first word when it has none;
            ``None`` is treated as ``0``.
        final_timestamp: End used when no later word has a start time.

    Returns:
        The word dicts that remain, in order. An empty input gives ``[]``.
    """
    # Guard against an empty transcript instead of failing on index 0.
    if not word_timestamps:
        return []

    # The first word has no predecessor, so it is seeded from initial_timestamp.
    first_word = word_timestamps[0]
    if first_word.get("start") is None:
        first_word["start"] = (
            initial_timestamp if initial_timestamp is not None else 0
        )
        first_word["end"] = _get_next_start_timestamp(
            word_timestamps, 0, final_timestamp
        )

    result = [first_word]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if ws doesn't have a start and end
        # use the previous end as start and next start as end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)

        # Entries whose text was merged into an earlier word are skipped.
        if ws["word"] is not None:
            result.append(ws)
    return result
def cleanup(path: str):
    """Delete the file, symbolic link or directory tree at ``path``.

    ``path`` may be relative or absolute. Regular files and symbolic links are
    removed with ``os.remove``; directories are removed recursively together
    with their content.

    Raises:
        ValueError: If ``path`` is not a file, a link or a directory.
    """
    # Links are tested before directories so that a link is unlinked and its
    # target is never walked.
    removable_entry = os.path.isfile(path) or os.path.islink(path)
    if removable_entry:
        os.remove(path)
        return

    if not os.path.isdir(path):
        raise ValueError("Path {} is not a file or dir.".format(path))

    shutil.rmtree(path)
def process_language_arg(language: str, model_name: str):
    """
    Validate the language argument and convert language names to language codes.

    ``language`` may be ``None``. Otherwise it is lower-cased and must be a key
    of ``LANGUAGES`` or of ``TO_LANGUAGE_CODE``; in the latter case the mapped
    value is used. For a model whose name ends in ``.en`` the result is always
    ``"en"``, and a warning is logged when another language was requested.

    Returns:
        The resolved language code, or ``None`` when no language was given and
        the model name does not end in ``.en``.

    Raises:
        ValueError: If the language is in neither lookup table.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    # Nothing to override unless an ".en" model was paired with a non-English
    # (or missing) language.
    if not model_name.endswith(".en") or language == "en":
        return language

    if language is not None:
        logging.warning(
            f"{model_name} is an English-only model but received '{language}'; using English instead."
        )
    return "en"