"""
Nemo diarizer
"""
import os
import json
import wget
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object
from pyannote.core import notebook
from diarizers.diarizer import Diarizer


class NemoDiarizer(Diarizer):
    """Class for the NeMo clustering diarizer."""

    def __init__(self, audio_path: str, data_dir: str):
        """
        NeMo diarizer class

        Args:
            audio_path (str): the path to the audio file
            data_dir (str): directory used for the manifest, the config and the diarization outputs
        """
        self.audio_path = audio_path
        self.data_dir = data_dir
        self.diarization = None
        self.manifest_dir = os.path.join(self.data_dir, 'input_manifest.json')
        self.model_config = os.path.join(self.data_dir, 'offline_diarization.yaml')

        # Download the default NeMo offline diarization config if it is not cached locally
        if not os.path.exists(self.model_config):
            config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/" \
                         "speaker_tasks/diarization/conf/offline_diarization.yaml"
            self.model_config = wget.download(config_url, self.data_dir)

        self.config = OmegaConf.load(self.model_config)

    def _create_manifest_file(self):
        """Create the inference manifest file."""
        meta = {
            'audio_filepath': self.audio_path,
            'offset': 0,
            'duration': None,
            'label': 'infer',
            'text': '-',
            'num_speakers': None,
            'rttm_filepath': None,
            'uem_filepath': None
        }
        with open(self.manifest_dir, 'w') as fp:
            json.dump(meta, fp)
            fp.write('\n')

    def _apply_config(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Edit the inference configuration loaded from the YAML file

        Args:
            pretrained_speaker_model (str): name of the pre-trained speaker embedding model,
                e.g. 'ecapa_tdnn' or 'titanet_large'; see
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speaker_diarization/results.html
        """
        pretrained_vad = 'vad_marblenet'
        self.config.num_workers = 0  # run data loading in the main process (CPU-only usage)
        output_dir = os.path.join(self.data_dir, 'outputs')

        self.config.diarizer.manifest_filepath = self.manifest_dir
        self.config.diarizer.out_dir = output_dir
        self.config.diarizer.ignore_overlap = False
        self.config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
        self.config.diarizer.speaker_embeddings.parameters.window_length_in_sec = 0.5
        self.config.diarizer.speaker_embeddings.parameters.shift_length_in_sec = 0.25
        self.config.diarizer.oracle_vad = False
        self.config.diarizer.clustering.parameters.oracle_num_speakers = False

        # Use NeMo's pretrained VAD model (oracle VAD is disabled above)
        self.config.diarizer.vad.model_path = pretrained_vad
        self.config.diarizer.vad.window_length_in_sec = 0.15
        self.config.diarizer.vad.shift_length_in_sec = 0.01
        self.config.diarizer.vad.parameters.onset = 0.8
        self.config.diarizer.vad.parameters.offset = 0.6
        self.config.diarizer.vad.parameters.min_duration_on = 0.1
        self.config.diarizer.vad.parameters.min_duration_off = 0.4

    def diarize_audio(self, pretrained_speaker_model: str = 'ecapa_tdnn'):
        """
        Diarize the audio file and store the result as a pyannote annotation

        Args:
            pretrained_speaker_model (str): name of the pre-trained speaker embedding model,
                e.g. 'ecapa_tdnn' or 'titanet_large'; see
                https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speaker_diarization/results.html
        """
        self._create_manifest_file()
        self._apply_config(pretrained_speaker_model)

        sd_model = ClusteringDiarizer(cfg=self.config)
        sd_model.diarize()

        # NeMo writes one RTTM file per input audio file under <out_dir>/pred_rttms
        audio_file_name_without_extension = os.path.splitext(os.path.basename(self.audio_path))[0]
        output_diarization_pred = os.path.join(
            self.data_dir, 'outputs', 'pred_rttms',
            f'{audio_file_name_without_extension}.rttm')
        pred_labels = rttm_to_labels(output_diarization_pred)
        self.diarization = labels_to_pyannote_object(pred_labels)

    def get_diarization_figure(self) -> plt.Figure:
        """Return the diarization result as a matplotlib figure."""
        if not self.diarization:
            self.diarize_audio()

        figure, ax = plt.subplots()
        notebook.plot_annotation(self.diarization, ax=ax, time=True, legend=True)
        return figure
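

# Minimal usage sketch. The file paths below are placeholders, not part of the
# module; data_dir is assumed to already exist, and NeMo downloads the
# pretrained VAD and speaker-embedding checkpoints on first use.
if __name__ == '__main__':
    diarizer = NemoDiarizer(audio_path='/tmp/example.wav', data_dir='/tmp/nemo_data')
    diarizer.diarize_audio(pretrained_speaker_model='ecapa_tdnn')
    print(diarizer.diarization)  # pyannote.core.Annotation with predicted speaker turns
    diarizer.get_diarization_figure().savefig('/tmp/diarization.png')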