johann22's picture
Upload 13 files
36501b5 verified
raw
history blame
15.8 kB
import json
import logging
import os
import shutil
import nltk
import wget
from omegaconf import OmegaConf
from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE
punct_model_langs = [
"en",
"fr",
"de",
"es",
"it",
"nl",
"pt",
"bg",
"pl",
"cs",
"sk",
"sl",
]
wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(
DEFAULT_ALIGN_MODELS_HF.keys()
)
whisper_langs = sorted(LANGUAGES.keys()) + sorted(
[k.title() for k in TO_LANGUAGE_CODE.keys()]
)
langs_to_iso = {
"aa": "aar",
"ab": "abk",
"ae": "ave",
"af": "afr",
"ak": "aka",
"am": "amh",
"an": "arg",
"ar": "ara",
"as": "asm",
"av": "ava",
"ay": "aym",
"az": "aze",
"ba": "bak",
"be": "bel",
"bg": "bul",
"bh": "bih",
"bi": "bis",
"bm": "bam",
"bn": "ben",
"bo": "tib",
"br": "bre",
"bs": "bos",
"ca": "cat",
"ce": "che",
"ch": "cha",
"co": "cos",
"cr": "cre",
"cs": "cze",
"cu": "chu",
"cv": "chv",
"cy": "wel",
"da": "dan",
"de": "ger",
"dv": "div",
"dz": "dzo",
"ee": "ewe",
"el": "gre",
"en": "eng",
"eo": "epo",
"es": "spa",
"et": "est",
"eu": "baq",
"fa": "per",
"ff": "ful",
"fi": "fin",
"fj": "fij",
"fo": "fao",
"fr": "fre",
"fy": "fry",
"ga": "gle",
"gd": "gla",
"gl": "glg",
"gn": "grn",
"gu": "guj",
"gv": "glv",
"ha": "hau",
"he": "heb",
"hi": "hin",
"ho": "hmo",
"hr": "hrv",
"ht": "hat",
"hu": "hun",
"hy": "arm",
"hz": "her",
"ia": "ina",
"id": "ind",
"ie": "ile",
"ig": "ibo",
"ii": "iii",
"ik": "ipk",
"io": "ido",
"is": "ice",
"it": "ita",
"iu": "iku",
"ja": "jpn",
"jv": "jav",
"ka": "geo",
"kg": "kon",
"ki": "kik",
"kj": "kua",
"kk": "kaz",
"kl": "kal",
"km": "khm",
"kn": "kan",
"ko": "kor",
"kr": "kau",
"ks": "kas",
"ku": "kur",
"kv": "kom",
"kw": "cor",
"ky": "kir",
"la": "lat",
"lb": "ltz",
"lg": "lug",
"li": "lim",
"ln": "lin",
"lo": "lao",
"lt": "lit",
"lu": "lub",
"lv": "lav",
"mg": "mlg",
"mh": "mah",
"mi": "mao",
"mk": "mac",
"ml": "mal",
"mn": "mon",
"mr": "mar",
"ms": "may",
"mt": "mlt",
"my": "bur",
"na": "nau",
"nb": "nob",
"nd": "nde",
"ne": "nep",
"ng": "ndo",
"nl": "dut",
"nn": "nno",
"no": "nor",
"nr": "nbl",
"nv": "nav",
"ny": "nya",
"oc": "oci",
"oj": "oji",
"om": "orm",
"or": "ori",
"os": "oss",
"pa": "pan",
"pi": "pli",
"pl": "pol",
"ps": "pus",
"pt": "por",
"qu": "que",
"rm": "roh",
"rn": "run",
"ro": "rum",
"ru": "rus",
"rw": "kin",
"sa": "san",
"sc": "srd",
"sd": "snd",
"se": "sme",
"sg": "sag",
"si": "sin",
"sk": "slo",
"sl": "slv",
"sm": "smo",
"sn": "sna",
"so": "som",
"sq": "alb",
"sr": "srp",
"ss": "ssw",
"st": "sot",
"su": "sun",
"sv": "swe",
"sw": "swa",
"ta": "tam",
"te": "tel",
"tg": "tgk",
"th": "tha",
"ti": "tir",
"tk": "tuk",
"tl": "tgl",
"tn": "tsn",
"to": "ton",
"tr": "tur",
"ts": "tso",
"tt": "tat",
"tw": "twi",
"ty": "tah",
"ug": "uig",
"uk": "ukr",
"ur": "urd",
"uz": "uzb",
"ve": "ven",
"vi": "vie",
"vo": "vol",
"wa": "wln",
"wo": "wol",
"xh": "xho",
"yi": "yid",
"yo": "yor",
"za": "zha",
"zh": "chi",
"zu": "zul",
}
def create_config(output_dir):
DOMAIN_TYPE = "telephonic" # Can be meeting, telephonic, or general based on domain type of the audio file
CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
if not os.path.exists(MODEL_CONFIG_PATH):
os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)
config = OmegaConf.load(MODEL_CONFIG_PATH)
data_dir = os.path.join(output_dir, "data")
os.makedirs(data_dir, exist_ok=True)
meta = {
"audio_filepath": os.path.join(output_dir, "mono_file.wav"),
"offset": 0,
"duration": None,
"label": "infer",
"text": "-",
"rttm_filepath": None,
"uem_filepath": None,
}
with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
json.dump(meta, fp)
fp.write("\n")
pretrained_vad = "vad_multilingual_marblenet"
pretrained_speaker_model = "titanet_large"
config.num_workers = 0
config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
config.diarizer.out_dir = (
output_dir # Directory to store intermediate files and prediction outputs
)
config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
config.diarizer.oracle_vad = (
False # compute VAD provided with model_path to vad config
)
config.diarizer.clustering.parameters.oracle_num_speakers = False
# Here, we use our in-house pretrained NeMo VAD model
config.diarizer.vad.model_path = pretrained_vad
config.diarizer.vad.parameters.onset = 0.8
config.diarizer.vad.parameters.offset = 0.6
config.diarizer.vad.parameters.pad_offset = -0.05
config.diarizer.msdd_model.model_path = (
"diar_msdd_telephonic" # Telephonic speaker diarization model
)
return config
def get_word_ts_anchor(s, e, option="start"):
if option == "end":
return e
elif option == "mid":
return (s + e) / 2
return s
def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
s, e, sp = spk_ts[0]
wrd_pos, turn_idx = 0, 0
wrd_spk_mapping = []
for wrd_dict in wrd_ts:
ws, we, wrd = (
int(wrd_dict["start"] * 1000),
int(wrd_dict["end"] * 1000),
wrd_dict["text"],
)
wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
while wrd_pos > float(e):
turn_idx += 1
turn_idx = min(turn_idx, len(spk_ts) - 1)
s, e, sp = spk_ts[turn_idx]
if turn_idx == len(spk_ts) - 1:
e = get_word_ts_anchor(ws, we, option="end")
wrd_spk_mapping.append(
{"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
)
return wrd_spk_mapping
sentence_ending_punctuations = ".?!"
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
is_word_sentence_end = (
lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
)
left_idx = word_idx
while (
left_idx > 0
and word_idx - left_idx < max_words
and speaker_list[left_idx - 1] == speaker_list[left_idx]
and not is_word_sentence_end(left_idx - 1)
):
left_idx -= 1
return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
is_word_sentence_end = (
lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
)
right_idx = word_idx
while (
right_idx < len(word_list) - 1
and right_idx - word_idx < max_words
and not is_word_sentence_end(right_idx)
):
right_idx += 1
return (
right_idx
if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
else -1
)
def get_realigned_ws_mapping_with_punctuation(
word_speaker_mapping, max_words_in_sentence=50
):
is_word_sentence_end = (
lambda x: x >= 0
and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
)
wsp_len = len(word_speaker_mapping)
words_list, speaker_list = [], []
for k, line_dict in enumerate(word_speaker_mapping):
word, speaker = line_dict["word"], line_dict["speaker"]
words_list.append(word)
speaker_list.append(speaker)
k = 0
while k < len(word_speaker_mapping):
line_dict = word_speaker_mapping[k]
if (
k < wsp_len - 1
and speaker_list[k] != speaker_list[k + 1]
and not is_word_sentence_end(k)
):
left_idx = get_first_word_idx_of_sentence(
k, words_list, speaker_list, max_words_in_sentence
)
right_idx = (
get_last_word_idx_of_sentence(
k, words_list, max_words_in_sentence - k + left_idx - 1
)
if left_idx > -1
else -1
)
if min(left_idx, right_idx) == -1:
k += 1
continue
spk_labels = speaker_list[left_idx : right_idx + 1]
mod_speaker = max(set(spk_labels), key=spk_labels.count)
if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
k += 1
continue
speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
right_idx - left_idx + 1
)
k = right_idx
k += 1
k, realigned_list = 0, []
while k < len(word_speaker_mapping):
line_dict = word_speaker_mapping[k].copy()
line_dict["speaker"] = speaker_list[k]
realigned_list.append(line_dict)
k += 1
return realigned_list
def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
s, e, spk = spk_ts[0]
prev_spk = spk
snts = []
snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}
for wrd_dict in word_speaker_mapping:
wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
s, e = wrd_dict["start_time"], wrd_dict["end_time"]
if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
snts.append(snt)
snt = {
"speaker": f"Speaker {spk}",
"start_time": s,
"end_time": e,
"text": "",
}
else:
snt["end_time"] = e
snt["text"] += wrd + " "
prev_spk = spk
snts.append(snt)
return snts
def get_speaker_aware_transcript(sentences_speaker_mapping, f):
previous_speaker = sentences_speaker_mapping[0]["speaker"]
f.write(f"{previous_speaker}: ")
for sentence_dict in sentences_speaker_mapping:
speaker = sentence_dict["speaker"]
sentence = sentence_dict["text"]
# If this speaker doesn't match the previous one, start a new paragraph
if speaker != previous_speaker:
f.write(f"\n\n{speaker}: ")
previous_speaker = speaker
# No matter what, write the current sentence
f.write(sentence + " ")
def format_timestamp(
milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
assert milliseconds >= 0, "non-negative timestamp expected"
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
def write_srt(transcript, file):
"""
Write a transcript to a file in SRT format.
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = [
-1,
]
for token, token_id in tokenizer.get_vocab().items():
has_numeral_symbol = any(c in "0123456789%$£" for c in token)
if has_numeral_symbol:
numeral_symbol_tokens.append(token_id)
return numeral_symbol_tokens
def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
# if current word is the last word
if current_word_index == len(word_timestamps) - 1:
return word_timestamps[current_word_index]["start"]
next_word_index = current_word_index + 1
while current_word_index < len(word_timestamps) - 1:
if word_timestamps[next_word_index].get("start") is None:
# if next word doesn't have a start timestamp
# merge it with the current word and delete it
word_timestamps[current_word_index]["word"] += (
" " + word_timestamps[next_word_index]["word"]
)
word_timestamps[next_word_index]["word"] = None
next_word_index += 1
if next_word_index == len(word_timestamps):
return final_timestamp
else:
return word_timestamps[next_word_index]["start"]
def filter_missing_timestamps(
word_timestamps, initial_timestamp=0, final_timestamp=None
):
# handle the first and last word
if word_timestamps[0].get("start") is None:
word_timestamps[0]["start"] = (
initial_timestamp if initial_timestamp is not None else 0
)
word_timestamps[0]["end"] = _get_next_start_timestamp(
word_timestamps, 0, final_timestamp
)
result = [
word_timestamps[0],
]
for i, ws in enumerate(word_timestamps[1:], start=1):
# if ws doesn't have a start and end
# use the previous end as start and next start as end
if ws.get("start") is None and ws.get("word") is not None:
ws["start"] = word_timestamps[i - 1]["end"]
ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)
if ws["word"] is not None:
result.append(ws)
return result
def cleanup(path: str):
"""path could either be relative or absolute."""
# check if file or directory exists
if os.path.isfile(path) or os.path.islink(path):
# remove file
os.remove(path)
elif os.path.isdir(path):
# remove directory and all its content
shutil.rmtree(path)
else:
raise ValueError("Path {} is not a file or dir.".format(path))
def process_language_arg(language: str, model_name: str):
"""
Process the language argument to make sure it's valid and convert language names to language codes.
"""
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if model_name.endswith(".en") and language != "en":
if language is not None:
logging.warning(
f"{model_name} is an English-only model but received '{language}'; using English instead."
)
language = "en"
return language