johann22 committed on
Commit
84d670e
1 Parent(s): fc4d8d2

Upload 13 files

whisper_diarization_main/.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ __pycache__/
2
+ *.srt
3
+ *.txt
4
+ *.wav
5
+ *.mp3
6
+ *.m4a
7
+
whisper_diarization_main/LICENSE ADDED
@@ -0,0 +1,24 @@
1
+ BSD 2-Clause License
2
+
3
+ Copyright (c) 2023, Mahmoud Ashraf
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
whisper_diarization_main/README.md ADDED
@@ -0,0 +1,93 @@
1
+ <h1 align="center">Speaker Diarization Using OpenAI Whisper</h1>
2
+ <p align="center">
3
+ <a href="https://github.com/MahmoudAshraf97/whisper-diarization/stargazers">
4
+ <img src="https://img.shields.io/github/stars/MahmoudAshraf97/whisper-diarization.svg?colorA=orange&colorB=orange&logo=github"
5
+ alt="GitHub stars">
6
+ </a>
7
+ <a href="https://github.com/MahmoudAshraf97/whisper-diarization/issues">
8
+ <img src="https://img.shields.io/github/issues/MahmoudAshraf97/whisper-diarization.svg"
9
+ alt="GitHub issues">
10
+ </a>
11
+ <a href="https://github.com/MahmoudAshraf97/whisper-diarization/blob/master/LICENSE">
12
+ <img src="https://img.shields.io/github/license/MahmoudAshraf97/whisper-diarization.svg"
13
+ alt="GitHub license">
14
+ </a>
15
+ <a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2FMahmoudAshraf97%2Fwhisper-diarization">
16
+ <img src="https://img.shields.io/twitter/url/https/github.com/MahmoudAshraf97/whisper-diarization.svg?style=social" alt="Twitter">
17
+ </a>
19
+ <a href="https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb">
20
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
21
+ </a>
22
+
23
+ </p>
24
+
25
+
26
+ #
27
+ Speaker Diarization pipeline based on OpenAI Whisper
28
+ I'd like to thank [@m-bain](https://github.com/m-bain) for batched Whisper inference and [@mu4farooqi](https://github.com/mu4farooqi) for the punctuation realignment algorithm.
29
+
30
+ <img src="https://github.blog/wp-content/uploads/2020/09/github-stars-logo_Color.png" alt="drawing" width="25"/> **Please star the project on GitHub (see the top-right corner) if you appreciate my contribution to the community!**
31
+
32
+ ## What is it
33
+ This repository combines Whisper ASR with Voice Activity Detection (VAD) and speaker embeddings to identify the speaker of each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to improve speaker embedding accuracy. The transcription is then generated with Whisper, and its timestamps are corrected and aligned using WhisperX to minimize diarization error caused by time shift. Next, the audio is passed through MarbleNet for VAD and segmentation to exclude silences, and TitaNet extracts speaker embeddings to identify the speaker of each segment. The result is associated with the word timestamps produced by WhisperX to detect the speaker of each word, and finally realigned using punctuation models to compensate for minor time shifts.
34
+
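To make the word-to-speaker step concrete, here is a minimal, self-contained sketch of the core idea: each transcribed word is assigned to whichever diarization turn its start timestamp falls into. The function and the sample data below are illustrative only; the actual implementation is `get_words_speaker_mapping()` in `helpers.py`.

```python
# Illustrative sketch (not this repository's API): assign each word to the
# diarization turn that contains its start timestamp.
def assign_speakers(words, turns):
    """words: list of (start_ms, end_ms, text); turns: list of (start_ms, end_ms, speaker)."""
    result, t = [], 0
    for start, end, text in words:
        # advance to the turn whose end time covers the word's start time
        while t < len(turns) - 1 and start > turns[t][1]:
            t += 1
        result.append({"word": text, "start": start, "end": end, "speaker": turns[t][2]})
    return result

if __name__ == "__main__":
    # made-up word timestamps (from the aligner) and speaker turns (from NeMo's RTTM output)
    words = [(0, 400, "Hello"), (450, 900, "there."), (950, 1500, "Hi!")]
    turns = [(0, 900, 0), (900, 1600, 1)]
    for w in assign_speakers(words, turns):
        print(w)
```

The real pipeline additionally realigns words around sentence boundaries using a punctuation model, as described above.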
35
+
36
+ WhisperX and NeMo parameters are hard-coded in diarize.py and helpers.py; CLI arguments to change them will be added later.
37
+ ## Installation
38
+ `PyTorch`, `FFMPEG`, and `Cython` are needed as prerequisites before installing the requirements:
39
+ ```
40
+ pip install cython torch
41
+ ```
42
+ or
43
+ ```
44
+ pip install torch
45
+ sudo apt update && sudo apt install cython3
46
+ ```
47
+ ```
48
+ # on Ubuntu or Debian
49
+ sudo apt update && sudo apt install ffmpeg
50
+
51
+ # on Arch Linux
52
+ sudo pacman -S ffmpeg
53
+
54
+ # on MacOS using Homebrew (https://brew.sh/)
55
+ brew install ffmpeg
56
+
57
+ # on Windows using Chocolatey (https://chocolatey.org/)
58
+ choco install ffmpeg
59
+
60
+ # on Windows using Scoop (https://scoop.sh/)
61
+ scoop install ffmpeg
62
+ ```
63
+ ```
64
+ pip install -r requirements.txt
65
+ ```
66
+ ## Usage
67
+
68
+ ```
69
+ python diarize.py -a AUDIO_FILE_NAME
70
+ ```
71
+
72
+ If your system has enough VRAM (>=10GB), you can use `diarize_parallel.py` instead. The difference is that it runs NeMo in parallel with Whisper, which can be beneficial in some cases, and the result is the same since the two models are independent of each other. This is still experimental, so expect errors and sharp edges. Your feedback is welcome.
73
+
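As a rough sketch of what `diarize_parallel.py` does (the script is included later in this commit), NeMo diarization is launched in a child process via `nemo_process.py` while Whisper transcription and forced alignment run in the parent; the `vocals.wav` path below is a placeholder:

```python
import subprocess

vocal_target = "vocals.wav"  # placeholder: input audio (optionally demucs-separated vocals)

# start NeMo diarization in a separate process so it overlaps with Whisper
nemo_process = subprocess.Popen(
    ["python3", "nemo_process.py", "-a", vocal_target, "--device", "cuda"]
)

# ... run Whisper transcription and forced alignment in this process ...

# wait for the child to finish writing temp_outputs/pred_rttms/mono_file.rttm
nemo_process.communicate()
```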
74
+ ## Command Line Options
75
+
76
+ - `-a AUDIO_FILE_NAME`: The name of the audio file to be processed
77
+ - `--no-stem`: Disables source separation
78
+ - `--whisper-model`: The model to be used for ASR, default is `medium.en`
79
+ - `--suppress_numerals`: Transcribes numbers as spelled-out words instead of digits; improves alignment accuracy
80
+ - `--device`: Choose which device to use, defaults to "cuda" if available
81
+ - `--language`: Manually select language, useful if language detection failed
82
+ - `--batch-size`: Batch size for batched inference; reduce if you run out of memory; set to 0 for non-batched inference
83
+
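For example (the audio filename is illustrative):

```
# default settings
python diarize.py -a episode.mp3

# larger model, no source separation, smaller batch size to reduce memory use
python diarize.py -a episode.mp3 --whisper-model large-v2 --no-stem --batch-size 4
```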
84
+ ## Known Limitations
85
+ - Overlapping speakers are not yet handled. A possible approach would be to separate the audio file and isolate only one speaker at a time, then feed each into the pipeline, but this would require much more computation
86
+ - There might be some errors; please raise an issue if you encounter any.
87
+
88
+ ## Future Improvements
89
+ - Implement a maximum length per sentence for SRT
90
+
91
+ ## Acknowledgements
92
+ Special thanks to [@adamjonas](https://github.com/adamjonas) for supporting this project.
93
+ This work is based on [OpenAI's Whisper](https://github.com/openai/whisper), [Faster Whisper](https://github.com/guillaumekln/faster-whisper), [Nvidia NeMo](https://github.com/NVIDIA/NeMo), and [Facebook's Demucs](https://github.com/facebookresearch/demucs).
whisper_diarization_main/Whisper_Transcription_+_NeMo_Diarization.ipynb ADDED
@@ -0,0 +1,1000 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {
7
+ "colab_type": "text",
8
+ "id": "view-in-github"
9
+ },
10
+ "source": [
11
+ "<a href=\"https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
12
+ ]
13
+ },
14
+ {
15
+ "attachments": {},
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "eCmjcOc9yEtQ"
19
+ },
20
+ "source": [
21
+ "# Installing Dependencies"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "Tn1c-CoDv2kw"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "!pip install git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n",
33
+ "!pip install --no-build-isolation nemo_toolkit[asr]==1.23.0\n",
34
+ "!pip install --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n",
35
+ "!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n",
36
+ "!pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {
43
+ "id": "YzhncHP0ytbQ"
44
+ },
45
+ "outputs": [],
46
+ "source": [
47
+ "import os\n",
48
+ "import wget\n",
49
+ "from omegaconf import OmegaConf\n",
50
+ "import json\n",
51
+ "import shutil\n",
52
+ "import torch\n",
53
+ "import torchaudio\n",
54
+ "from nemo.collections.asr.models.msdd_models import NeuralDiarizer\n",
55
+ "from deepmultilingualpunctuation import PunctuationModel\n",
56
+ "import re\n",
57
+ "import logging\n",
58
+ "import nltk\n",
59
+ "from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE\n",
60
+ "from ctc_forced_aligner import (\n",
61
+ " load_alignment_model,\n",
62
+ " generate_emissions,\n",
63
+ " preprocess_text,\n",
64
+ " get_alignments,\n",
65
+ " get_spans,\n",
66
+ " postprocess_results,\n",
67
+ ")"
68
+ ]
69
+ },
70
+ {
71
+ "attachments": {},
72
+ "cell_type": "markdown",
73
+ "metadata": {
74
+ "id": "jbsUt3SwyhjD"
75
+ },
76
+ "source": [
77
+ "# Helper Functions"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "id": "Se6Hc7CZygxu"
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "punct_model_langs = [\n",
89
+ " \"en\",\n",
90
+ " \"fr\",\n",
91
+ " \"de\",\n",
92
+ " \"es\",\n",
93
+ " \"it\",\n",
94
+ " \"nl\",\n",
95
+ " \"pt\",\n",
96
+ " \"bg\",\n",
97
+ " \"pl\",\n",
98
+ " \"cs\",\n",
99
+ " \"sk\",\n",
100
+ " \"sl\",\n",
101
+ "]\n",
102
+ "langs_to_iso = {\n",
103
+ " \"af\": \"afr\",\n",
104
+ " \"am\": \"amh\",\n",
105
+ " \"ar\": \"ara\",\n",
106
+ " \"as\": \"asm\",\n",
107
+ " \"az\": \"aze\",\n",
108
+ " \"ba\": \"bak\",\n",
109
+ " \"be\": \"bel\",\n",
110
+ " \"bg\": \"bul\",\n",
111
+ " \"bn\": \"ben\",\n",
112
+ " \"bo\": \"tib\",\n",
113
+ " \"br\": \"bre\",\n",
114
+ " \"bs\": \"bos\",\n",
115
+ " \"ca\": \"cat\",\n",
116
+ " \"cs\": \"cze\",\n",
117
+ " \"cy\": \"wel\",\n",
118
+ " \"da\": \"dan\",\n",
119
+ " \"de\": \"ger\",\n",
120
+ " \"el\": \"gre\",\n",
121
+ " \"en\": \"eng\",\n",
122
+ " \"es\": \"spa\",\n",
123
+ " \"et\": \"est\",\n",
124
+ " \"eu\": \"baq\",\n",
125
+ " \"fa\": \"per\",\n",
126
+ " \"fi\": \"fin\",\n",
127
+ " \"fo\": \"fao\",\n",
128
+ " \"fr\": \"fre\",\n",
129
+ " \"gl\": \"glg\",\n",
130
+ " \"gu\": \"guj\",\n",
131
+ " \"ha\": \"hau\",\n",
132
+ " \"haw\": \"haw\",\n",
133
+ " \"he\": \"heb\",\n",
134
+ " \"hi\": \"hin\",\n",
135
+ " \"hr\": \"hrv\",\n",
136
+ " \"ht\": \"hat\",\n",
137
+ " \"hu\": \"hun\",\n",
138
+ " \"hy\": \"arm\",\n",
139
+ " \"id\": \"ind\",\n",
140
+ " \"is\": \"ice\",\n",
141
+ " \"it\": \"ita\",\n",
142
+ " \"ja\": \"jpn\",\n",
143
+ " \"jw\": \"jav\",\n",
144
+ " \"ka\": \"geo\",\n",
145
+ " \"kk\": \"kaz\",\n",
146
+ " \"km\": \"khm\",\n",
147
+ " \"kn\": \"kan\",\n",
148
+ " \"ko\": \"kor\",\n",
149
+ " \"la\": \"lat\",\n",
150
+ " \"lb\": \"ltz\",\n",
151
+ " \"ln\": \"lin\",\n",
152
+ " \"lo\": \"lao\",\n",
153
+ " \"lt\": \"lit\",\n",
154
+ " \"lv\": \"lav\",\n",
155
+ " \"mg\": \"mlg\",\n",
156
+ " \"mi\": \"mao\",\n",
157
+ " \"mk\": \"mac\",\n",
158
+ " \"ml\": \"mal\",\n",
159
+ " \"mn\": \"mon\",\n",
160
+ " \"mr\": \"mar\",\n",
161
+ " \"ms\": \"may\",\n",
162
+ " \"mt\": \"mlt\",\n",
163
+ " \"my\": \"bur\",\n",
164
+ " \"ne\": \"nep\",\n",
165
+ " \"nl\": \"dut\",\n",
166
+ " \"nn\": \"nno\",\n",
167
+ " \"no\": \"nor\",\n",
168
+ " \"oc\": \"oci\",\n",
169
+ " \"pa\": \"pan\",\n",
170
+ " \"pl\": \"pol\",\n",
171
+ " \"ps\": \"pus\",\n",
172
+ " \"pt\": \"por\",\n",
173
+ " \"ro\": \"rum\",\n",
174
+ " \"ru\": \"rus\",\n",
175
+ " \"sa\": \"san\",\n",
176
+ " \"sd\": \"snd\",\n",
177
+ " \"si\": \"sin\",\n",
178
+ " \"sk\": \"slo\",\n",
179
+ " \"sl\": \"slv\",\n",
180
+ " \"sn\": \"sna\",\n",
181
+ " \"so\": \"som\",\n",
182
+ " \"sq\": \"alb\",\n",
183
+ " \"sr\": \"srp\",\n",
184
+ " \"su\": \"sun\",\n",
185
+ " \"sv\": \"swe\",\n",
186
+ " \"sw\": \"swa\",\n",
187
+ " \"ta\": \"tam\",\n",
188
+ " \"te\": \"tel\",\n",
189
+ " \"tg\": \"tgk\",\n",
190
+ " \"th\": \"tha\",\n",
191
+ " \"tk\": \"tuk\",\n",
192
+ " \"tl\": \"tgl\",\n",
193
+ " \"tr\": \"tur\",\n",
194
+ " \"tt\": \"tat\",\n",
195
+ " \"uk\": \"ukr\",\n",
196
+ " \"ur\": \"urd\",\n",
197
+ " \"uz\": \"uzb\",\n",
198
+ " \"vi\": \"vie\",\n",
199
+ " \"yi\": \"yid\",\n",
200
+ " \"yo\": \"yor\",\n",
201
+ " \"yue\": \"yue\",\n",
202
+ " \"zh\": \"chi\",\n",
203
+ "}\n",
204
+ "\n",
205
+ "\n",
206
+ "whisper_langs = sorted(LANGUAGES.keys()) + sorted(\n",
207
+ " [k.title() for k in TO_LANGUAGE_CODE.keys()]\n",
208
+ ")\n",
209
+ "\n",
210
+ "\n",
211
+ "def create_config(output_dir):\n",
212
+ " DOMAIN_TYPE = \"telephonic\" # Can be meeting, telephonic, or general based on domain type of the audio file\n",
213
+ " CONFIG_FILE_NAME = f\"diar_infer_{DOMAIN_TYPE}.yaml\"\n",
214
+ " CONFIG_URL = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}\"\n",
215
+ " MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)\n",
216
+ " if not os.path.exists(MODEL_CONFIG):\n",
217
+ " MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)\n",
218
+ "\n",
219
+ " config = OmegaConf.load(MODEL_CONFIG)\n",
220
+ "\n",
221
+ " data_dir = os.path.join(output_dir, \"data\")\n",
222
+ " os.makedirs(data_dir, exist_ok=True)\n",
223
+ "\n",
224
+ " meta = {\n",
225
+ " \"audio_filepath\": os.path.join(output_dir, \"mono_file.wav\"),\n",
226
+ " \"offset\": 0,\n",
227
+ " \"duration\": None,\n",
228
+ " \"label\": \"infer\",\n",
229
+ " \"text\": \"-\",\n",
230
+ " \"rttm_filepath\": None,\n",
231
+ " \"uem_filepath\": None,\n",
232
+ " }\n",
233
+ " with open(os.path.join(data_dir, \"input_manifest.json\"), \"w\") as fp:\n",
234
+ " json.dump(meta, fp)\n",
235
+ " fp.write(\"\\n\")\n",
236
+ "\n",
237
+ " pretrained_vad = \"vad_multilingual_marblenet\"\n",
238
+ " pretrained_speaker_model = \"titanet_large\"\n",
239
+ " config.num_workers = 0 # Workaround for multiprocessing hanging with ipython issue\n",
240
+ " config.diarizer.manifest_filepath = os.path.join(data_dir, \"input_manifest.json\")\n",
241
+ " config.diarizer.out_dir = (\n",
242
+ " output_dir # Directory to store intermediate files and prediction outputs\n",
243
+ " )\n",
244
+ "\n",
245
+ " config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
246
+ " config.diarizer.oracle_vad = (\n",
247
+ " False # compute VAD provided with model_path to vad config\n",
248
+ " )\n",
249
+ " config.diarizer.clustering.parameters.oracle_num_speakers = False\n",
250
+ "\n",
251
+ " # Here, we use our in-house pretrained NeMo VAD model\n",
252
+ " config.diarizer.vad.model_path = pretrained_vad\n",
253
+ " config.diarizer.vad.parameters.onset = 0.8\n",
254
+ " config.diarizer.vad.parameters.offset = 0.6\n",
255
+ " config.diarizer.vad.parameters.pad_offset = -0.05\n",
256
+ " config.diarizer.msdd_model.model_path = (\n",
257
+ " \"diar_msdd_telephonic\" # Telephonic speaker diarization model\n",
258
+ " )\n",
259
+ "\n",
260
+ " return config\n",
261
+ "\n",
262
+ "\n",
263
+ "def get_word_ts_anchor(s, e, option=\"start\"):\n",
264
+ " if option == \"end\":\n",
265
+ " return e\n",
266
+ " elif option == \"mid\":\n",
267
+ " return (s + e) / 2\n",
268
+ " return s\n",
269
+ "\n",
270
+ "\n",
271
+ "def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option=\"start\"):\n",
272
+ " s, e, sp = spk_ts[0]\n",
273
+ " wrd_pos, turn_idx = 0, 0\n",
274
+ " wrd_spk_mapping = []\n",
275
+ " for wrd_dict in wrd_ts:\n",
276
+ " ws, we, wrd = (\n",
277
+ " int(wrd_dict[\"start\"] * 1000),\n",
278
+ " int(wrd_dict[\"end\"] * 1000),\n",
279
+ " wrd_dict[\"text\"],\n",
280
+ " )\n",
281
+ " wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)\n",
282
+ " while wrd_pos > float(e):\n",
283
+ " turn_idx += 1\n",
284
+ " turn_idx = min(turn_idx, len(spk_ts) - 1)\n",
285
+ " s, e, sp = spk_ts[turn_idx]\n",
286
+ " if turn_idx == len(spk_ts) - 1:\n",
287
+ " e = get_word_ts_anchor(ws, we, option=\"end\")\n",
288
+ " wrd_spk_mapping.append(\n",
289
+ " {\"word\": wrd, \"start_time\": ws, \"end_time\": we, \"speaker\": sp}\n",
290
+ " )\n",
291
+ " return wrd_spk_mapping\n",
292
+ "\n",
293
+ "\n",
294
+ "sentence_ending_punctuations = \".?!\"\n",
295
+ "\n",
296
+ "\n",
297
+ "def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):\n",
298
+ " is_word_sentence_end = (\n",
299
+ " lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
300
+ " )\n",
301
+ " left_idx = word_idx\n",
302
+ " while (\n",
303
+ " left_idx > 0\n",
304
+ " and word_idx - left_idx < max_words\n",
305
+ " and speaker_list[left_idx - 1] == speaker_list[left_idx]\n",
306
+ " and not is_word_sentence_end(left_idx - 1)\n",
307
+ " ):\n",
308
+ " left_idx -= 1\n",
309
+ "\n",
310
+ " return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1\n",
311
+ "\n",
312
+ "\n",
313
+ "def get_last_word_idx_of_sentence(word_idx, word_list, max_words):\n",
314
+ " is_word_sentence_end = (\n",
315
+ " lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
316
+ " )\n",
317
+ " right_idx = word_idx\n",
318
+ " while (\n",
319
+ " right_idx < len(word_list) - 1\n",
320
+ " and right_idx - word_idx < max_words\n",
321
+ " and not is_word_sentence_end(right_idx)\n",
322
+ " ):\n",
323
+ " right_idx += 1\n",
324
+ "\n",
325
+ " return (\n",
326
+ " right_idx\n",
327
+ " if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)\n",
328
+ " else -1\n",
329
+ " )\n",
330
+ "\n",
331
+ "\n",
332
+ "def get_realigned_ws_mapping_with_punctuation(\n",
333
+ " word_speaker_mapping, max_words_in_sentence=50\n",
334
+ "):\n",
335
+ " is_word_sentence_end = (\n",
336
+ " lambda x: x >= 0\n",
337
+ " and word_speaker_mapping[x][\"word\"][-1] in sentence_ending_punctuations\n",
338
+ " )\n",
339
+ " wsp_len = len(word_speaker_mapping)\n",
340
+ "\n",
341
+ " words_list, speaker_list = [], []\n",
342
+ " for k, line_dict in enumerate(word_speaker_mapping):\n",
343
+ " word, speaker = line_dict[\"word\"], line_dict[\"speaker\"]\n",
344
+ " words_list.append(word)\n",
345
+ " speaker_list.append(speaker)\n",
346
+ "\n",
347
+ " k = 0\n",
348
+ " while k < len(word_speaker_mapping):\n",
349
+ " line_dict = word_speaker_mapping[k]\n",
350
+ " if (\n",
351
+ " k < wsp_len - 1\n",
352
+ " and speaker_list[k] != speaker_list[k + 1]\n",
353
+ " and not is_word_sentence_end(k)\n",
354
+ " ):\n",
355
+ " left_idx = get_first_word_idx_of_sentence(\n",
356
+ " k, words_list, speaker_list, max_words_in_sentence\n",
357
+ " )\n",
358
+ " right_idx = (\n",
359
+ " get_last_word_idx_of_sentence(\n",
360
+ " k, words_list, max_words_in_sentence - k + left_idx - 1\n",
361
+ " )\n",
362
+ " if left_idx > -1\n",
363
+ " else -1\n",
364
+ " )\n",
365
+ " if min(left_idx, right_idx) == -1:\n",
366
+ " k += 1\n",
367
+ " continue\n",
368
+ "\n",
369
+ " spk_labels = speaker_list[left_idx : right_idx + 1]\n",
370
+ " mod_speaker = max(set(spk_labels), key=spk_labels.count)\n",
371
+ " if spk_labels.count(mod_speaker) < len(spk_labels) // 2:\n",
372
+ " k += 1\n",
373
+ " continue\n",
374
+ "\n",
375
+ " speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (\n",
376
+ " right_idx - left_idx + 1\n",
377
+ " )\n",
378
+ " k = right_idx\n",
379
+ "\n",
380
+ " k += 1\n",
381
+ "\n",
382
+ " k, realigned_list = 0, []\n",
383
+ " while k < len(word_speaker_mapping):\n",
384
+ " line_dict = word_speaker_mapping[k].copy()\n",
385
+ " line_dict[\"speaker\"] = speaker_list[k]\n",
386
+ " realigned_list.append(line_dict)\n",
387
+ " k += 1\n",
388
+ "\n",
389
+ " return realigned_list\n",
390
+ "\n",
391
+ "\n",
392
+ "def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):\n",
393
+ " sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak\n",
394
+ " s, e, spk = spk_ts[0]\n",
395
+ " prev_spk = spk\n",
396
+ "\n",
397
+ " snts = []\n",
398
+ " snt = {\"speaker\": f\"Speaker {spk}\", \"start_time\": s, \"end_time\": e, \"text\": \"\"}\n",
399
+ "\n",
400
+ " for wrd_dict in word_speaker_mapping:\n",
401
+ " wrd, spk = wrd_dict[\"word\"], wrd_dict[\"speaker\"]\n",
402
+ " s, e = wrd_dict[\"start_time\"], wrd_dict[\"end_time\"]\n",
403
+ " if spk != prev_spk or sentence_checker(snt[\"text\"] + \" \" + wrd):\n",
404
+ " snts.append(snt)\n",
405
+ " snt = {\n",
406
+ " \"speaker\": f\"Speaker {spk}\",\n",
407
+ " \"start_time\": s,\n",
408
+ " \"end_time\": e,\n",
409
+ " \"text\": \"\",\n",
410
+ " }\n",
411
+ " else:\n",
412
+ " snt[\"end_time\"] = e\n",
413
+ " snt[\"text\"] += wrd + \" \"\n",
414
+ " prev_spk = spk\n",
415
+ "\n",
416
+ " snts.append(snt)\n",
417
+ " return snts\n",
418
+ "\n",
419
+ "\n",
420
+ "def get_speaker_aware_transcript(sentences_speaker_mapping, f):\n",
421
+ " previous_speaker = sentences_speaker_mapping[0][\"speaker\"]\n",
422
+ " f.write(f\"{previous_speaker}: \")\n",
423
+ "\n",
424
+ " for sentence_dict in sentences_speaker_mapping:\n",
425
+ " speaker = sentence_dict[\"speaker\"]\n",
426
+ " sentence = sentence_dict[\"text\"]\n",
427
+ "\n",
428
+ " # If this speaker doesn't match the previous one, start a new paragraph\n",
429
+ " if speaker != previous_speaker:\n",
430
+ " f.write(f\"\\n\\n{speaker}: \")\n",
431
+ " previous_speaker = speaker\n",
432
+ "\n",
433
+ " # No matter what, write the current sentence\n",
434
+ " f.write(sentence + \" \")\n",
435
+ "\n",
436
+ "\n",
437
+ "def format_timestamp(\n",
438
+ " milliseconds: float, always_include_hours: bool = False, decimal_marker: str = \".\"\n",
439
+ "):\n",
440
+ " assert milliseconds >= 0, \"non-negative timestamp expected\"\n",
441
+ "\n",
442
+ " hours = milliseconds // 3_600_000\n",
443
+ " milliseconds -= hours * 3_600_000\n",
444
+ "\n",
445
+ " minutes = milliseconds // 60_000\n",
446
+ " milliseconds -= minutes * 60_000\n",
447
+ "\n",
448
+ " seconds = milliseconds // 1_000\n",
449
+ " milliseconds -= seconds * 1_000\n",
450
+ "\n",
451
+ " hours_marker = f\"{hours:02d}:\" if always_include_hours or hours > 0 else \"\"\n",
452
+ " return (\n",
453
+ " f\"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}\"\n",
454
+ " )\n",
455
+ "\n",
456
+ "\n",
457
+ "def write_srt(transcript, file):\n",
458
+ " \"\"\"\n",
459
+ " Write a transcript to a file in SRT format.\n",
460
+ "\n",
461
+ " \"\"\"\n",
462
+ " for i, segment in enumerate(transcript, start=1):\n",
463
+ " # write srt lines\n",
464
+ " print(\n",
465
+ " f\"{i}\\n\"\n",
466
+ " f\"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> \"\n",
467
+ " f\"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\\n\"\n",
468
+ " f\"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\\n\",\n",
469
+ " file=file,\n",
470
+ " flush=True,\n",
471
+ " )\n",
472
+ "\n",
473
+ "\n",
474
+ "def find_numeral_symbol_tokens(tokenizer):\n",
475
+ " numeral_symbol_tokens = [\n",
476
+ " -1,\n",
477
+ " ]\n",
478
+ " for token, token_id in tokenizer.get_vocab().items():\n",
479
+ " has_numeral_symbol = any(c in \"0123456789%$£\" for c in token)\n",
480
+ " if has_numeral_symbol:\n",
481
+ " numeral_symbol_tokens.append(token_id)\n",
482
+ " return numeral_symbol_tokens\n",
483
+ "\n",
484
+ "\n",
485
+ "def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):\n",
486
+ " # if current word is the last word\n",
487
+ " if current_word_index == len(word_timestamps) - 1:\n",
488
+ " return word_timestamps[current_word_index][\"start\"]\n",
489
+ "\n",
490
+ " next_word_index = current_word_index + 1\n",
491
+ " while current_word_index < len(word_timestamps) - 1:\n",
492
+ " if word_timestamps[next_word_index].get(\"start\") is None:\n",
493
+ " # if next word doesn't have a start timestamp\n",
494
+ " # merge it with the current word and delete it\n",
495
+ " word_timestamps[current_word_index][\"word\"] += (\n",
496
+ " \" \" + word_timestamps[next_word_index][\"word\"]\n",
497
+ " )\n",
498
+ "\n",
499
+ " word_timestamps[next_word_index][\"word\"] = None\n",
500
+ " next_word_index += 1\n",
501
+ " if next_word_index == len(word_timestamps):\n",
502
+ " return final_timestamp\n",
503
+ "\n",
504
+ " else:\n",
505
+ " return word_timestamps[next_word_index][\"start\"]\n",
506
+ "\n",
507
+ "\n",
508
+ "def filter_missing_timestamps(\n",
509
+ " word_timestamps, initial_timestamp=0, final_timestamp=None\n",
510
+ "):\n",
511
+ " # handle the first and last word\n",
512
+ " if word_timestamps[0].get(\"start\") is None:\n",
513
+ " word_timestamps[0][\"start\"] = (\n",
514
+ " initial_timestamp if initial_timestamp is not None else 0\n",
515
+ " )\n",
516
+ " word_timestamps[0][\"end\"] = _get_next_start_timestamp(\n",
517
+ " word_timestamps, 0, final_timestamp\n",
518
+ " )\n",
519
+ "\n",
520
+ " result = [\n",
521
+ " word_timestamps[0],\n",
522
+ " ]\n",
523
+ "\n",
524
+ " for i, ws in enumerate(word_timestamps[1:], start=1):\n",
525
+ " # if ws doesn't have a start and end\n",
526
+ " # use the previous end as start and next start as end\n",
527
+ " if ws.get(\"start\") is None and ws.get(\"word\") is not None:\n",
528
+ " ws[\"start\"] = word_timestamps[i - 1][\"end\"]\n",
529
+ " ws[\"end\"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)\n",
530
+ "\n",
531
+ " if ws[\"word\"] is not None:\n",
532
+ " result.append(ws)\n",
533
+ " return result\n",
534
+ "\n",
535
+ "\n",
536
+ "def cleanup(path: str):\n",
537
+ " \"\"\"path could either be relative or absolute.\"\"\"\n",
538
+ " # check if file or directory exists\n",
539
+ " if os.path.isfile(path) or os.path.islink(path):\n",
540
+ " # remove file\n",
541
+ " os.remove(path)\n",
542
+ " elif os.path.isdir(path):\n",
543
+ " # remove directory and all its content\n",
544
+ " shutil.rmtree(path)\n",
545
+ " else:\n",
546
+ " raise ValueError(\"Path {} is not a file or dir.\".format(path))\n",
547
+ "\n",
548
+ "\n",
549
+ "def process_language_arg(language: str, model_name: str):\n",
550
+ " \"\"\"\n",
551
+ " Process the language argument to make sure it's valid and convert language names to language codes.\n",
552
+ " \"\"\"\n",
553
+ " if language is not None:\n",
554
+ " language = language.lower()\n",
555
+ " if language not in LANGUAGES:\n",
556
+ " if language in TO_LANGUAGE_CODE:\n",
557
+ " language = TO_LANGUAGE_CODE[language]\n",
558
+ " else:\n",
559
+ " raise ValueError(f\"Unsupported language: {language}\")\n",
560
+ "\n",
561
+ " if model_name.endswith(\".en\") and language != \"en\":\n",
562
+ " if language is not None:\n",
563
+ " logging.warning(\n",
564
+ " f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
565
+ " )\n",
566
+ " language = \"en\"\n",
567
+ " return language\n",
568
+ "\n",
569
+ "\n",
570
+ "def transcribe_batched(\n",
571
+ " audio_file: str,\n",
572
+ " language: str,\n",
573
+ " batch_size: int,\n",
574
+ " model_name: str,\n",
575
+ " compute_dtype: str,\n",
576
+ " suppress_numerals: bool,\n",
577
+ " device: str,\n",
578
+ "):\n",
579
+ " import whisperx\n",
580
+ "\n",
581
+ " # Faster Whisper batched\n",
582
+ " whisper_model = whisperx.load_model(\n",
583
+ " model_name,\n",
584
+ " device,\n",
585
+ " compute_type=compute_dtype,\n",
586
+ " asr_options={\"suppress_numerals\": suppress_numerals},\n",
587
+ " )\n",
588
+ " audio = whisperx.load_audio(audio_file)\n",
589
+ " result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
590
+ " del whisper_model\n",
591
+ " torch.cuda.empty_cache()\n",
592
+ " return result[\"segments\"], result[\"language\"], audio"
593
+ ]
594
+ },
595
+ {
596
+ "attachments": {},
597
+ "cell_type": "markdown",
598
+ "metadata": {
599
+ "id": "B7qWQb--1Xcw"
600
+ },
601
+ "source": [
602
+ "# Options"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": null,
608
+ "metadata": {
609
+ "id": "ONlFrSnD0FOp"
610
+ },
611
+ "outputs": [],
612
+ "source": [
613
+ "# Name of the audio file\n",
614
+ "audio_path = \"20200128-Pieter Wuille (part 1 of 2) - Episode 1.mp3\"\n",
615
+ "\n",
616
+ "# Whether to enable music removal from speech; helps increase diarization quality but uses a lot of RAM\n",
617
+ "enable_stemming = True\n",
618
+ "\n",
619
+ "# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large')\n",
620
+ "whisper_model_name = \"large-v2\"\n",
621
+ "\n",
622
+ "# replaces numerical digits with their pronunciation, increases diarization accuracy\n",
623
+ "suppress_numerals = True\n",
624
+ "\n",
625
+ "batch_size = 8\n",
626
+ "\n",
627
+ "language = None # autodetect language\n",
628
+ "\n",
629
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
630
+ ]
631
+ },
632
+ {
633
+ "attachments": {},
634
+ "cell_type": "markdown",
635
+ "metadata": {
636
+ "id": "h-cY1ZEy2KVI"
637
+ },
638
+ "source": [
639
+ "# Processing"
640
+ ]
641
+ },
642
+ {
643
+ "attachments": {},
644
+ "cell_type": "markdown",
645
+ "metadata": {
646
+ "id": "7ZS4xXmE2NGP"
647
+ },
648
+ "source": [
649
+ "## Separating music from speech using Demucs\n",
650
+ "\n",
651
+ "---\n",
652
+ "\n",
653
+ "By isolating the vocals from the rest of the audio, it becomes easier to identify and track individual speakers based on the spectral and temporal characteristics of their speech signals. Source separation is just one of many techniques that can be used as a preprocessing step to help improve the accuracy and reliability of the overall diarization process."
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": null,
659
+ "metadata": {
660
+ "colab": {
661
+ "base_uri": "https://localhost:8080/"
662
+ },
663
+ "id": "HKcgQUrAzsJZ",
664
+ "outputId": "dc2a1d96-20da-4749-9d64-21edacfba1b1"
665
+ },
666
+ "outputs": [],
667
+ "source": [
668
+ "if enable_stemming:\n",
669
+ " # Isolate vocals from the rest of the audio\n",
670
+ "\n",
671
+ " return_code = os.system(\n",
672
+ " f'python3 -m demucs.separate -n htdemucs --two-stems=vocals \"{audio_path}\" -o \"temp_outputs\"'\n",
673
+ " )\n",
674
+ "\n",
675
+ " if return_code != 0:\n",
676
+ " logging.warning(\"Source splitting failed, using original audio file.\")\n",
677
+ " vocal_target = audio_path\n",
678
+ " else:\n",
679
+ " vocal_target = os.path.join(\n",
680
+ " \"temp_outputs\",\n",
681
+ " \"htdemucs\",\n",
682
+ " os.path.splitext(os.path.basename(audio_path))[0],\n",
683
+ " \"vocals.wav\",\n",
684
+ " )\n",
685
+ "else:\n",
686
+ " vocal_target = audio_path"
687
+ ]
688
+ },
689
+ {
690
+ "attachments": {},
691
+ "cell_type": "markdown",
692
+ "metadata": {
693
+ "id": "UYg9VWb22Tz8"
694
+ },
695
+ "source": [
696
+ "## Transcribing audio using Whisper and realigning timestamps using Wav2Vec2\n",
697
+ "---\n",
698
+ "This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.\n",
699
+ "\n",
700
+ "The first model is called OpenAI Whisper, which is a speech recognition model that can transcribe speech with high accuracy. The code loads the whisper model and uses it to transcribe the vocal_target file.\n",
701
+ "\n",
702
+ "The output of the transcription process is a set of text segments with corresponding timestamps indicating when each segment was spoken.\n"
703
+ ]
704
+ },
705
+ {
706
+ "cell_type": "code",
707
+ "execution_count": null,
708
+ "metadata": {
709
+ "id": "5-VKFn530oTl"
710
+ },
711
+ "outputs": [],
712
+ "source": [
713
+ "compute_type = \"float16\"\n",
714
+ "# or run on GPU with INT8\n",
715
+ "# compute_type = \"int8_float16\"\n",
716
+ "# or run on CPU with INT8\n",
717
+ "# compute_type = \"int8\"\n",
718
+ "\n",
719
+ "whisper_results, language, audio_waveform = transcribe_batched(\n",
720
+ " vocal_target,\n",
721
+ " language,\n",
722
+ " batch_size,\n",
723
+ " whisper_model_name,\n",
724
+ " compute_type,\n",
725
+ " suppress_numerals,\n",
726
+ " device,\n",
727
+ ")"
728
+ ]
729
+ },
730
+ {
731
+ "attachments": {},
732
+ "cell_type": "markdown",
733
+ "metadata": {},
734
+ "source": [
735
+ "## Aligning the transcription with the original audio using Wav2Vec2\n",
736
+ "---\n",
737
+ "The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.\n",
738
+ "\n",
739
+ "The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
740
+ "\n",
741
+ "By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.\n",
742
+ "\n",
743
+ "If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead."
744
+ ]
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "execution_count": null,
749
+ "metadata": {},
750
+ "outputs": [],
751
+ "source": [
752
+ "alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(\n",
753
+ " device,\n",
754
+ " dtype=torch.float16 if device == \"cuda\" else torch.float32,\n",
755
+ ")\n",
756
+ "\n",
757
+ "audio_waveform = (\n",
758
+ " torch.from_numpy(audio_waveform)\n",
759
+ " .to(alignment_model.dtype)\n",
760
+ " .to(alignment_model.device)\n",
761
+ ")\n",
762
+ "\n",
763
+ "emissions, stride = generate_emissions(\n",
764
+ " alignment_model, audio_waveform, batch_size=batch_size\n",
765
+ ")\n",
766
+ "\n",
767
+ "del alignment_model\n",
768
+ "torch.cuda.empty_cache()\n",
769
+ "\n",
770
+ "full_transcript = \"\".join(segment[\"text\"] for segment in whisper_results)\n",
771
+ "\n",
772
+ "tokens_starred, text_starred = preprocess_text(\n",
773
+ " full_transcript,\n",
774
+ " romanize=True,\n",
775
+ " language=langs_to_iso[language],\n",
776
+ ")\n",
777
+ "\n",
778
+ "segments, scores, blank_id = get_alignments(\n",
779
+ " emissions,\n",
780
+ " tokens_starred,\n",
781
+ " alignment_dictionary,\n",
782
+ ")\n",
783
+ "\n",
784
+ "spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))\n",
785
+ "\n",
786
+ "word_timestamps = postprocess_results(text_starred, spans, stride, scores)"
787
+ ]
788
+ },
789
+ {
790
+ "attachments": {},
791
+ "cell_type": "markdown",
792
+ "metadata": {
793
+ "id": "7EEaJPsQ21Rx"
794
+ },
795
+ "source": [
796
+ "## Convert audio to mono for NeMo compatibility"
797
+ ]
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "execution_count": null,
802
+ "metadata": {},
803
+ "outputs": [],
804
+ "source": [
805
+ "ROOT = os.getcwd()\n",
806
+ "temp_path = os.path.join(ROOT, \"temp_outputs\")\n",
807
+ "os.makedirs(temp_path, exist_ok=True)\n",
808
+ "torchaudio.save(\n",
809
+ " os.path.join(temp_path, \"mono_file.wav\"),\n",
810
+ " audio_waveform.cpu().unsqueeze(0).float(),\n",
811
+ " 16000,\n",
812
+ " channels_first=True,\n",
813
+ ")"
814
+ ]
815
+ },
816
+ {
817
+ "attachments": {},
818
+ "cell_type": "markdown",
819
+ "metadata": {
820
+ "id": "D1gkViCf2-CV"
821
+ },
822
+ "source": [
823
+ "## Speaker Diarization using NeMo MSDD Model\n",
824
+ "---\n",
825
+ "This code uses a model called Nvidia NeMo MSDD (Multi-scale Diarization Decoder) to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time."
826
+ ]
827
+ },
828
+ {
829
+ "cell_type": "code",
830
+ "execution_count": null,
831
+ "metadata": {
832
+ "id": "C7jIpBCH02RL"
833
+ },
834
+ "outputs": [],
835
+ "source": [
836
+ "# Initialize NeMo MSDD diarization model\n",
837
+ "msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(\"cuda\")\n",
838
+ "msdd_model.diarize()\n",
839
+ "\n",
840
+ "del msdd_model\n",
841
+ "torch.cuda.empty_cache()"
842
+ ]
843
+ },
844
+ {
845
+ "attachments": {},
846
+ "cell_type": "markdown",
847
+ "metadata": {
848
+ "id": "NmkZYaDAEOAg"
849
+ },
850
+ "source": [
851
+ "## Mapping Speakers to Sentences According to Timestamps"
852
+ ]
853
+ },
854
+ {
855
+ "cell_type": "code",
856
+ "execution_count": null,
857
+ "metadata": {
858
+ "id": "E65LUGQe02zw"
859
+ },
860
+ "outputs": [],
861
+ "source": [
862
+ "# Reading timestamps <> Speaker Labels mapping\n",
863
+ "\n",
864
+ "speaker_ts = []\n",
865
+ "with open(os.path.join(temp_path, \"pred_rttms\", \"mono_file.rttm\"), \"r\") as f:\n",
866
+ " lines = f.readlines()\n",
867
+ " for line in lines:\n",
868
+ " line_list = line.split(\" \")\n",
869
+ " s = int(float(line_list[5]) * 1000)\n",
870
+ " e = s + int(float(line_list[8]) * 1000)\n",
871
+ " speaker_ts.append([s, e, int(line_list[11].split(\"_\")[-1])])\n",
872
+ "\n",
873
+ "wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, \"start\")"
874
+ ]
875
+ },
876
+ {
877
+ "attachments": {},
878
+ "cell_type": "markdown",
879
+ "metadata": {
880
+ "id": "8Ruxc8S1EXtW"
881
+ },
882
+ "source": [
883
+ "## Realigning Speech segments using Punctuation\n",
884
+ "---\n",
885
+ "\n",
886
+ "This code provides a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. It uses punctuation markings to determine the dominant speaker for each sentence in the transcription.\n",
887
+ "\n",
888
+ "```\n",
889
+ "Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are\n",
890
+ "Speaker B: going to suck, right? So it's actually it hits you on both sides.\n",
891
+ "```\n",
892
+ "\n",
893
+ "For example, if a sentence is split between two speakers, the code takes the mode of speaker labels for each word in the sentence, and uses that speaker label for the whole sentence. This can help to improve the accuracy of speaker diarization, especially in cases where the Whisper model may not take fine utterances like \"hmm\" and \"yeah\" into account, but the Diarization Model (Nemo) may include them, leading to inconsistent results.\n",
894
+ "\n",
895
+ "The code also handles cases where one speaker is giving a monologue while other speakers are making occasional comments in the background. It ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time. This provides a robust and reliable method for realigning speech segments to their respective speakers based on punctuation in the transcription."
896
+ ]
897
+ },
898
+ {
899
+ "cell_type": "code",
900
+ "execution_count": null,
901
+ "metadata": {
902
+ "id": "pgfC5hA41BXu"
903
+ },
904
+ "outputs": [],
905
+ "source": [
906
+ "if language in punct_model_langs:\n",
907
+ " # restoring punctuation in the transcript to help realign the sentences\n",
908
+ " punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
909
+ "\n",
910
+ " words_list = list(map(lambda x: x[\"word\"], wsm))\n",
911
+ "\n",
912
+ " labled_words = punct_model.predict(words_list,chunk_size=230)\n",
913
+ "\n",
914
+ " ending_puncts = \".?!\"\n",
915
+ " model_puncts = \".,;:!?\"\n",
916
+ "\n",
917
+ " # We don't want to punctuate U.S.A. with a period. Right?\n",
918
+ " is_acronym = lambda x: re.fullmatch(r\"\\b(?:[a-zA-Z]\\.){2,}\", x)\n",
919
+ "\n",
920
+ " for word_dict, labeled_tuple in zip(wsm, labled_words):\n",
921
+ " word = word_dict[\"word\"]\n",
922
+ " if (\n",
923
+ " word\n",
924
+ " and labeled_tuple[1] in ending_puncts\n",
925
+ " and (word[-1] not in model_puncts or is_acronym(word))\n",
926
+ " ):\n",
927
+ " word += labeled_tuple[1]\n",
928
+ " if word.endswith(\"..\"):\n",
929
+ " word = word.rstrip(\".\")\n",
930
+ " word_dict[\"word\"] = word\n",
931
+ "\n",
932
+ "else:\n",
933
+ " logging.warning(\n",
934
+ " f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
935
+ " )\n",
936
+ "\n",
937
+ "wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
938
+ "ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"
939
+ ]
940
+ },
941
+ {
942
+ "attachments": {},
943
+ "cell_type": "markdown",
944
+ "metadata": {
945
+ "id": "vF2QAtLOFvwZ"
946
+ },
947
+ "source": [
948
+ "## Cleanup and Exporting the results"
949
+ ]
950
+ },
951
+ {
952
+ "cell_type": "code",
953
+ "execution_count": null,
954
+ "metadata": {
955
+ "id": "kFTyKI6B1MI0"
956
+ },
957
+ "outputs": [],
958
+ "source": [
959
+ "with open(f\"{os.path.splitext(audio_path)[0]}.txt\", \"w\", encoding=\"utf-8-sig\") as f:\n",
960
+ " get_speaker_aware_transcript(ssm, f)\n",
961
+ "\n",
962
+ "with open(f\"{os.path.splitext(audio_path)[0]}.srt\", \"w\", encoding=\"utf-8-sig\") as srt:\n",
963
+ " write_srt(ssm, srt)\n",
964
+ "\n",
965
+ "cleanup(temp_path)"
966
+ ]
967
+ }
968
+ ],
969
+ "metadata": {
970
+ "accelerator": "GPU",
971
+ "colab": {
972
+ "authorship_tag": "ABX9TyOyiQNkD+ROzss634BOsrSh",
973
+ "collapsed_sections": [
974
+ "eCmjcOc9yEtQ",
975
+ "jbsUt3SwyhjD"
976
+ ],
977
+ "include_colab_link": true,
978
+ "provenance": []
979
+ },
980
+ "gpuClass": "standard",
981
+ "kernelspec": {
982
+ "display_name": "Python 3",
983
+ "name": "python3"
984
+ },
985
+ "language_info": {
986
+ "codemirror_mode": {
987
+ "name": "ipython",
988
+ "version": 3
989
+ },
990
+ "file_extension": ".py",
991
+ "mimetype": "text/x-python",
992
+ "name": "python",
993
+ "nbconvert_exporter": "python",
994
+ "pygments_lexer": "ipython3",
995
+ "version": "3.10.12"
996
+ }
997
+ },
998
+ "nbformat": 4,
999
+ "nbformat_minor": 0
1000
+ }
whisper_diarization_main/diarize.py ADDED
@@ -0,0 +1,235 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import re
5
+
6
+ import torch
7
+ import torchaudio
8
+ from ctc_forced_aligner import (
9
+ generate_emissions,
10
+ get_alignments,
11
+ get_spans,
12
+ load_alignment_model,
13
+ postprocess_results,
14
+ preprocess_text,
15
+ )
16
+ from deepmultilingualpunctuation import PunctuationModel
17
+ from nemo.collections.asr.models.msdd_models import NeuralDiarizer
18
+
19
+ from helpers import (
20
+ cleanup,
21
+ create_config,
22
+ get_realigned_ws_mapping_with_punctuation,
23
+ get_sentences_speaker_mapping,
24
+ get_speaker_aware_transcript,
25
+ get_words_speaker_mapping,
26
+ langs_to_iso,
27
+ punct_model_langs,
28
+ whisper_langs,
29
+ write_srt,
30
+ )
31
+ from transcription_helpers import transcribe_batched
32
+
33
+ mtypes = {"cpu": "int8", "cuda": "float16"}
34
+
35
+ # Initialize parser
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument(
38
+ "-a", "--audio", help="name of the target audio file", required=True
39
+ )
40
+ parser.add_argument(
41
+ "--no-stem",
42
+ action="store_false",
43
+ dest="stemming",
44
+ default=True,
45
+ help="Disables source separation."
46
+ "This helps with long files that don't contain a lot of music.",
47
+ )
48
+
49
+ parser.add_argument(
50
+ "--suppress_numerals",
51
+ action="store_true",
52
+ dest="suppress_numerals",
53
+ default=False,
54
+ help="Suppresses Numerical Digits."
55
+ "This helps the diarization accuracy but converts all digits into written text.",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--whisper-model",
60
+ dest="model_name",
61
+ default="medium.en",
62
+ help="name of the Whisper model to use",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--batch-size",
67
+ type=int,
68
+ dest="batch_size",
69
+ default=8,
70
+ help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
71
+ )
72
+
73
+ parser.add_argument(
74
+ "--language",
75
+ type=str,
76
+ default=None,
77
+ choices=whisper_langs,
78
+ help="Language spoken in the audio, specify None to perform language detection",
79
+ )
80
+
81
+ parser.add_argument(
82
+ "--device",
83
+ dest="device",
84
+ default="cuda" if torch.cuda.is_available() else "cpu",
85
+ help="if you have a GPU use 'cuda', otherwise 'cpu'",
86
+ )
87
+
88
+ args = parser.parse_args()
89
+
90
+ if args.stemming:
91
+ # Isolate vocals from the rest of the audio
92
+
93
+ return_code = os.system(
94
+ f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
95
+ )
96
+
97
+ if return_code != 0:
98
+ logging.warning(
99
+ "Source splitting failed, using original audio file. Use --no-stem argument to disable it."
100
+ )
101
+ vocal_target = args.audio
102
+ else:
103
+ vocal_target = os.path.join(
104
+ "temp_outputs",
105
+ "htdemucs",
106
+ os.path.splitext(os.path.basename(args.audio))[0],
107
+ "vocals.wav",
108
+ )
109
+ else:
110
+ vocal_target = args.audio
111
+
112
+
113
+ # Transcribe the audio file
114
+
115
+ whisper_results, language, audio_waveform = transcribe_batched(
116
+ vocal_target,
117
+ args.language,
118
+ args.batch_size,
119
+ args.model_name,
120
+ mtypes[args.device],
121
+ args.suppress_numerals,
122
+ args.device,
123
+ )
124
+
125
+ # Forced Alignment
126
+ alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
127
+ args.device,
128
+ dtype=torch.float16 if args.device == "cuda" else torch.float32,
129
+ )
130
+
131
+ audio_waveform = (
132
+ torch.from_numpy(audio_waveform)
133
+ .to(alignment_model.dtype)
134
+ .to(alignment_model.device)
135
+ )
136
+ emissions, stride = generate_emissions(
137
+ alignment_model, audio_waveform, batch_size=args.batch_size
138
+ )
139
+
140
+ del alignment_model
141
+ torch.cuda.empty_cache()
142
+
143
+ full_transcript = "".join(segment["text"] for segment in whisper_results)
144
+
145
+ tokens_starred, text_starred = preprocess_text(
146
+ full_transcript,
147
+ romanize=True,
148
+ language=langs_to_iso[language],
149
+ )
150
+
151
+ segments, scores, blank_id = get_alignments(
152
+ emissions,
153
+ tokens_starred,
154
+ alignment_dictionary,
155
+ )
156
+
157
+ spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))
158
+
159
+ word_timestamps = postprocess_results(text_starred, spans, stride, scores)
160
+
161
+
162
+ # convert audio to mono for NeMo compatibility
163
+ ROOT = os.getcwd()
164
+ temp_path = os.path.join(ROOT, "temp_outputs")
165
+ os.makedirs(temp_path, exist_ok=True)
166
+ torchaudio.save(
167
+ os.path.join(temp_path, "mono_file.wav"),
168
+ audio_waveform.cpu().unsqueeze(0).float(),
169
+ 16000,
170
+ channels_first=True,
171
+ )
172
+
173
+
174
+ # Initialize NeMo MSDD diarization model
175
+ msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
176
+ msdd_model.diarize()
177
+
178
+ del msdd_model
179
+ torch.cuda.empty_cache()
180
+
181
+ # Reading timestamps <> Speaker Labels mapping
182
+
183
+
184
+ speaker_ts = []
185
+ with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
186
+ lines = f.readlines()
187
+ for line in lines:
188
+ line_list = line.split(" ")
189
+ s = int(float(line_list[5]) * 1000)
190
+ e = s + int(float(line_list[8]) * 1000)
191
+ speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
192
+
193
+ wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
194
+
195
+ if language in punct_model_langs:
196
+ # restoring punctuation in the transcript to help realign the sentences
197
+ punct_model = PunctuationModel(model="kredor/punctuate-all")
198
+
199
+ words_list = list(map(lambda x: x["word"], wsm))
200
+
201
+ labled_words = punct_model.predict(words_list, chunk_size=230)
202
+
203
+ ending_puncts = ".?!"
204
+ model_puncts = ".,;:!?"
205
+
206
+ # We don't want to punctuate U.S.A. with a period. Right?
207
+ is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
208
+
209
+ for word_dict, labeled_tuple in zip(wsm, labled_words):
210
+ word = word_dict["word"]
211
+ if (
212
+ word
213
+ and labeled_tuple[1] in ending_puncts
214
+ and (word[-1] not in model_puncts or is_acronym(word))
215
+ ):
216
+ word += labeled_tuple[1]
217
+ if word.endswith(".."):
218
+ word = word.rstrip(".")
219
+ word_dict["word"] = word
220
+
221
+ else:
222
+ logging.warning(
223
+ f"Punctuation restoration is not available for {language} language. Using the original punctuation."
224
+ )
225
+
226
+ wsm = get_realigned_ws_mapping_with_punctuation(wsm)
227
+ ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
228
+
229
+ with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
230
+ get_speaker_aware_transcript(ssm, f)
231
+
232
+ with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
233
+ write_srt(ssm, srt)
234
+
235
+ cleanup(temp_path)
whisper_diarization_main/diarize_parallel.py ADDED
@@ -0,0 +1,217 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import re
5
+ import subprocess
6
+
7
+ import torch
8
+ from ctc_forced_aligner import (
9
+ generate_emissions,
10
+ get_alignments,
11
+ get_spans,
12
+ load_alignment_model,
13
+ postprocess_results,
14
+ preprocess_text,
15
+ )
16
+ from deepmultilingualpunctuation import PunctuationModel
17
+
18
+ from helpers import (
19
+ cleanup,
20
+ get_realigned_ws_mapping_with_punctuation,
21
+ get_sentences_speaker_mapping,
22
+ get_speaker_aware_transcript,
23
+ get_words_speaker_mapping,
24
+ langs_to_iso,
25
+ punct_model_langs,
26
+ whisper_langs,
27
+ write_srt,
28
+ )
29
+ from transcription_helpers import transcribe_batched
30
+
31
+ mtypes = {"cpu": "int8", "cuda": "float16"}
32
+
33
+ # Initialize parser
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "-a", "--audio", help="name of the target audio file", required=True
37
+ )
38
+ parser.add_argument(
39
+ "--no-stem",
40
+ action="store_false",
41
+ dest="stemming",
42
+ default=True,
43
+ help="Disables source separation."
44
+ "This helps with long files that don't contain a lot of music.",
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--suppress_numerals",
49
+ action="store_true",
50
+ dest="suppress_numerals",
51
+ default=False,
52
+ help="Suppresses Numerical Digits."
53
+ "This helps the diarization accuracy but converts all digits into written text.",
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--whisper-model",
58
+ dest="model_name",
59
+ default="medium.en",
60
+ help="name of the Whisper model to use",
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--batch-size",
65
+ type=int,
66
+ dest="batch_size",
67
+ default=8,
68
+ help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
69
+ )
70
+
71
+ parser.add_argument(
72
+ "--language",
73
+ type=str,
74
+ default=None,
75
+ choices=whisper_langs,
76
+ help="Language spoken in the audio, specify None to perform language detection",
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--device",
81
+ dest="device",
82
+ default="cuda" if torch.cuda.is_available() else "cpu",
83
+ help="if you have a GPU use 'cuda', otherwise 'cpu'",
84
+ )
85
+
86
+ args = parser.parse_args()
87
+
88
+ if args.stemming:
89
+ # Isolate vocals from the rest of the audio
90
+
91
+ return_code = os.system(
92
+ f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
93
+ )
94
+
95
+ if return_code != 0:
96
+ logging.warning(
97
+ "Source splitting failed, using original audio file. Use --no-stem argument to disable it."
98
+ )
99
+ vocal_target = args.audio
100
+ else:
101
+ vocal_target = os.path.join(
102
+ "temp_outputs",
103
+ "htdemucs",
104
+ os.path.splitext(os.path.basename(args.audio))[0],
105
+ "vocals.wav",
106
+ )
107
+ else:
108
+ vocal_target = args.audio
109
+
110
+ logging.info("Starting Nemo process with vocal_target: %s", vocal_target)
111
+ nemo_process = subprocess.Popen(
112
+ ["python3", "nemo_process.py", "-a", vocal_target, "--device", args.device],
113
+ )
114
+ # Transcribe the audio file
115
+ whisper_results, language, audio_waveform = transcribe_batched(
116
+ vocal_target,
117
+ args.language,
118
+ args.batch_size,
119
+ args.model_name,
120
+ mtypes[args.device],
121
+ args.suppress_numerals,
122
+ args.device,
123
+ )
124
+
125
+ # Forced Alignment
126
+ alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
127
+ args.device,
128
+ dtype=torch.float16 if args.device == "cuda" else torch.float32,
129
+ )
130
+
131
+ audio_waveform = (
132
+ torch.from_numpy(audio_waveform)
133
+ .to(alignment_model.dtype)
134
+ .to(alignment_model.device)
135
+ )
136
+ emissions, stride = generate_emissions(
137
+ alignment_model, audio_waveform, batch_size=args.batch_size
138
+ )
139
+
140
+ del alignment_model
141
+ torch.cuda.empty_cache()
142
+
143
+ full_transcript = "".join(segment["text"] for segment in whisper_results)
144
+
145
+ tokens_starred, text_starred = preprocess_text(
146
+ full_transcript,
147
+ romanize=True,
148
+ language=langs_to_iso[language],
149
+ )
150
+
151
+ segments, scores, blank_id = get_alignments(
152
+ emissions,
153
+ tokens_starred,
154
+ alignment_dictionary,
155
+ )
156
+
157
+ spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))
158
+
159
+ word_timestamps = postprocess_results(text_starred, spans, stride, scores)
160
+
161
+ # Reading timestamps <> Speaker Labels mapping
162
+ nemo_process.communicate()
163
+ ROOT = os.getcwd()
164
+ temp_path = os.path.join(ROOT, "temp_outputs")
165
+
166
+ speaker_ts = []
167
+ with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
168
+ lines = f.readlines()
169
+ for line in lines:
170
+ line_list = line.split(" ")
171
+ s = int(float(line_list[5]) * 1000)
172
+ e = s + int(float(line_list[8]) * 1000)
173
+ speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
174
+
175
+ wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
176
+
177
+ if language in punct_model_langs:
178
+ # restoring punctuation in the transcript to help realign the sentences
179
+ punct_model = PunctuationModel(model="kredor/punctuate-all")
180
+
181
+ words_list = list(map(lambda x: x["word"], wsm))
182
+
183
+ labeled_words = punct_model.predict(words_list, chunk_size=230)
184
+
185
+ ending_puncts = ".?!"
186
+ model_puncts = ".,;:!?"
187
+
188
+ # We don't want to punctuate U.S.A. with a period. Right?
189
+ is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
190
+
191
+ for word_dict, labeled_tuple in zip(wsm, labeled_words):
192
+ word = word_dict["word"]
193
+ if (
194
+ word
195
+ and labeled_tuple[1] in ending_puncts
196
+ and (word[-1] not in model_puncts or is_acronym(word))
197
+ ):
198
+ word += labeled_tuple[1]
199
+ if word.endswith(".."):
200
+ word = word.rstrip(".")
201
+ word_dict["word"] = word
202
+
203
+ else:
204
+ logging.warning(
205
+ f"Punctuation restoration is not available for {language} language. Using the original punctuation."
206
+ )
207
+
208
+ wsm = get_realigned_ws_mapping_with_punctuation(wsm)
209
+ ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
210
+
211
+ with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
212
+ get_speaker_aware_transcript(ssm, f)
213
+
214
+ with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
215
+ write_srt(ssm, srt)
216
+
217
+ cleanup(temp_path)
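For reference, the RTTM parsing near the end of the script above relies on NeMo's RTTM writer padding the start and duration fields with extra spaces, which is why a plain split on a single space lands the start time at index 5, the duration at index 8, and the speaker label at index 11. A minimal sketch with a made-up line (the spacing and values are illustrative, not taken from a real run):

sample = "SPEAKER mono_file 1   0.840   3.220 <NA> <NA> speaker_0 <NA> <NA>"
fields = sample.split(" ")                          # consecutive spaces yield empty strings
start_ms = int(float(fields[5]) * 1000)             # 840
end_ms = start_ms + int(float(fields[8]) * 1000)    # 4060
speaker_idx = int(fields[11].split("_")[-1])        # 0
print([start_ms, end_ms, speaker_idx])              # [840, 4060, 0]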
whisper_diarization_main/helpers.py ADDED
@@ -0,0 +1,580 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import shutil
5
+
6
+ import nltk
7
+ import wget
8
+ from omegaconf import OmegaConf
9
+ from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
10
+ from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE
11
+
12
+ punct_model_langs = [
13
+ "en",
14
+ "fr",
15
+ "de",
16
+ "es",
17
+ "it",
18
+ "nl",
19
+ "pt",
20
+ "bg",
21
+ "pl",
22
+ "cs",
23
+ "sk",
24
+ "sl",
25
+ ]
26
+ wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(
27
+ DEFAULT_ALIGN_MODELS_HF.keys()
28
+ )
29
+
30
+ whisper_langs = sorted(LANGUAGES.keys()) + sorted(
31
+ [k.title() for k in TO_LANGUAGE_CODE.keys()]
32
+ )
33
+
34
+ langs_to_iso = {
35
+ "aa": "aar",
36
+ "ab": "abk",
37
+ "ae": "ave",
38
+ "af": "afr",
39
+ "ak": "aka",
40
+ "am": "amh",
41
+ "an": "arg",
42
+ "ar": "ara",
43
+ "as": "asm",
44
+ "av": "ava",
45
+ "ay": "aym",
46
+ "az": "aze",
47
+ "ba": "bak",
48
+ "be": "bel",
49
+ "bg": "bul",
50
+ "bh": "bih",
51
+ "bi": "bis",
52
+ "bm": "bam",
53
+ "bn": "ben",
54
+ "bo": "tib",
55
+ "br": "bre",
56
+ "bs": "bos",
57
+ "ca": "cat",
58
+ "ce": "che",
59
+ "ch": "cha",
60
+ "co": "cos",
61
+ "cr": "cre",
62
+ "cs": "cze",
63
+ "cu": "chu",
64
+ "cv": "chv",
65
+ "cy": "wel",
66
+ "da": "dan",
67
+ "de": "ger",
68
+ "dv": "div",
69
+ "dz": "dzo",
70
+ "ee": "ewe",
71
+ "el": "gre",
72
+ "en": "eng",
73
+ "eo": "epo",
74
+ "es": "spa",
75
+ "et": "est",
76
+ "eu": "baq",
77
+ "fa": "per",
78
+ "ff": "ful",
79
+ "fi": "fin",
80
+ "fj": "fij",
81
+ "fo": "fao",
82
+ "fr": "fre",
83
+ "fy": "fry",
84
+ "ga": "gle",
85
+ "gd": "gla",
86
+ "gl": "glg",
87
+ "gn": "grn",
88
+ "gu": "guj",
89
+ "gv": "glv",
90
+ "ha": "hau",
91
+ "he": "heb",
92
+ "hi": "hin",
93
+ "ho": "hmo",
94
+ "hr": "hrv",
95
+ "ht": "hat",
96
+ "hu": "hun",
97
+ "hy": "arm",
98
+ "hz": "her",
99
+ "ia": "ina",
100
+ "id": "ind",
101
+ "ie": "ile",
102
+ "ig": "ibo",
103
+ "ii": "iii",
104
+ "ik": "ipk",
105
+ "io": "ido",
106
+ "is": "ice",
107
+ "it": "ita",
108
+ "iu": "iku",
109
+ "ja": "jpn",
110
+ "jv": "jav",
111
+ "ka": "geo",
112
+ "kg": "kon",
113
+ "ki": "kik",
114
+ "kj": "kua",
115
+ "kk": "kaz",
116
+ "kl": "kal",
117
+ "km": "khm",
118
+ "kn": "kan",
119
+ "ko": "kor",
120
+ "kr": "kau",
121
+ "ks": "kas",
122
+ "ku": "kur",
123
+ "kv": "kom",
124
+ "kw": "cor",
125
+ "ky": "kir",
126
+ "la": "lat",
127
+ "lb": "ltz",
128
+ "lg": "lug",
129
+ "li": "lim",
130
+ "ln": "lin",
131
+ "lo": "lao",
132
+ "lt": "lit",
133
+ "lu": "lub",
134
+ "lv": "lav",
135
+ "mg": "mlg",
136
+ "mh": "mah",
137
+ "mi": "mao",
138
+ "mk": "mac",
139
+ "ml": "mal",
140
+ "mn": "mon",
141
+ "mr": "mar",
142
+ "ms": "may",
143
+ "mt": "mlt",
144
+ "my": "bur",
145
+ "na": "nau",
146
+ "nb": "nob",
147
+ "nd": "nde",
148
+ "ne": "nep",
149
+ "ng": "ndo",
150
+ "nl": "dut",
151
+ "nn": "nno",
152
+ "no": "nor",
153
+ "nr": "nbl",
154
+ "nv": "nav",
155
+ "ny": "nya",
156
+ "oc": "oci",
157
+ "oj": "oji",
158
+ "om": "orm",
159
+ "or": "ori",
160
+ "os": "oss",
161
+ "pa": "pan",
162
+ "pi": "pli",
163
+ "pl": "pol",
164
+ "ps": "pus",
165
+ "pt": "por",
166
+ "qu": "que",
167
+ "rm": "roh",
168
+ "rn": "run",
169
+ "ro": "rum",
170
+ "ru": "rus",
171
+ "rw": "kin",
172
+ "sa": "san",
173
+ "sc": "srd",
174
+ "sd": "snd",
175
+ "se": "sme",
176
+ "sg": "sag",
177
+ "si": "sin",
178
+ "sk": "slo",
179
+ "sl": "slv",
180
+ "sm": "smo",
181
+ "sn": "sna",
182
+ "so": "som",
183
+ "sq": "alb",
184
+ "sr": "srp",
185
+ "ss": "ssw",
186
+ "st": "sot",
187
+ "su": "sun",
188
+ "sv": "swe",
189
+ "sw": "swa",
190
+ "ta": "tam",
191
+ "te": "tel",
192
+ "tg": "tgk",
193
+ "th": "tha",
194
+ "ti": "tir",
195
+ "tk": "tuk",
196
+ "tl": "tgl",
197
+ "tn": "tsn",
198
+ "to": "ton",
199
+ "tr": "tur",
200
+ "ts": "tso",
201
+ "tt": "tat",
202
+ "tw": "twi",
203
+ "ty": "tah",
204
+ "ug": "uig",
205
+ "uk": "ukr",
206
+ "ur": "urd",
207
+ "uz": "uzb",
208
+ "ve": "ven",
209
+ "vi": "vie",
210
+ "vo": "vol",
211
+ "wa": "wln",
212
+ "wo": "wol",
213
+ "xh": "xho",
214
+ "yi": "yid",
215
+ "yo": "yor",
216
+ "za": "zha",
217
+ "zh": "chi",
218
+ "zu": "zul",
219
+ }
220
+
221
+
222
+ def create_config(output_dir):
223
+ DOMAIN_TYPE = "telephonic" # Can be meeting, telephonic, or general based on domain type of the audio file
224
+ CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
225
+ CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
226
+ MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
227
+ if not os.path.exists(MODEL_CONFIG_PATH):
228
+ os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
229
+ CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
230
+ MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)
231
+
232
+ config = OmegaConf.load(MODEL_CONFIG_PATH)
233
+
234
+ data_dir = os.path.join(output_dir, "data")
235
+ os.makedirs(data_dir, exist_ok=True)
236
+
237
+ meta = {
238
+ "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
239
+ "offset": 0,
240
+ "duration": None,
241
+ "label": "infer",
242
+ "text": "-",
243
+ "rttm_filepath": None,
244
+ "uem_filepath": None,
245
+ }
246
+ with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
247
+ json.dump(meta, fp)
248
+ fp.write("\n")
249
+
250
+ pretrained_vad = "vad_multilingual_marblenet"
251
+ pretrained_speaker_model = "titanet_large"
252
+ config.num_workers = 0
253
+ config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
254
+ config.diarizer.out_dir = (
255
+ output_dir # Directory to store intermediate files and prediction outputs
256
+ )
257
+
258
+ config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
259
+ config.diarizer.oracle_vad = (
260
+ False # compute VAD provided with model_path to vad config
261
+ )
262
+ config.diarizer.clustering.parameters.oracle_num_speakers = False
263
+
264
+ # Here, we use our in-house pretrained NeMo VAD model
265
+ config.diarizer.vad.model_path = pretrained_vad
266
+ config.diarizer.vad.parameters.onset = 0.8
267
+ config.diarizer.vad.parameters.offset = 0.6
268
+ config.diarizer.vad.parameters.pad_offset = -0.05
269
+ config.diarizer.msdd_model.model_path = (
270
+ "diar_msdd_telephonic" # Telephonic speaker diarization model
271
+ )
272
+
273
+ return config
274
+
275
+
276
+ def get_word_ts_anchor(s, e, option="start"):
277
+ if option == "end":
278
+ return e
279
+ elif option == "mid":
280
+ return (s + e) / 2
281
+ return s
282
+
283
+
284
+ def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
285
+ s, e, sp = spk_ts[0]
286
+ wrd_pos, turn_idx = 0, 0
287
+ wrd_spk_mapping = []
288
+ for wrd_dict in wrd_ts:
289
+ ws, we, wrd = (
290
+ int(wrd_dict["start"] * 1000),
291
+ int(wrd_dict["end"] * 1000),
292
+ wrd_dict["text"],
293
+ )
294
+ wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
295
+ while wrd_pos > float(e):
296
+ turn_idx += 1
297
+ turn_idx = min(turn_idx, len(spk_ts) - 1)
298
+ s, e, sp = spk_ts[turn_idx]
299
+ if turn_idx == len(spk_ts) - 1:
300
+ e = get_word_ts_anchor(ws, we, option="end")
301
+ wrd_spk_mapping.append(
302
+ {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
303
+ )
304
+ return wrd_spk_mapping
305
+
306
+
307
+ sentence_ending_punctuations = ".?!"
308
+
309
+
310
+ def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
311
+ is_word_sentence_end = (
312
+ lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
313
+ )
314
+ left_idx = word_idx
315
+ while (
316
+ left_idx > 0
317
+ and word_idx - left_idx < max_words
318
+ and speaker_list[left_idx - 1] == speaker_list[left_idx]
319
+ and not is_word_sentence_end(left_idx - 1)
320
+ ):
321
+ left_idx -= 1
322
+
323
+ return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
324
+
325
+
326
+ def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
327
+ is_word_sentence_end = (
328
+ lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
329
+ )
330
+ right_idx = word_idx
331
+ while (
332
+ right_idx < len(word_list) - 1
333
+ and right_idx - word_idx < max_words
334
+ and not is_word_sentence_end(right_idx)
335
+ ):
336
+ right_idx += 1
337
+
338
+ return (
339
+ right_idx
340
+ if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
341
+ else -1
342
+ )
343
+
344
+
345
+ def get_realigned_ws_mapping_with_punctuation(
346
+ word_speaker_mapping, max_words_in_sentence=50
347
+ ):
348
+ is_word_sentence_end = (
349
+ lambda x: x >= 0
350
+ and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
351
+ )
352
+ wsp_len = len(word_speaker_mapping)
353
+
354
+ words_list, speaker_list = [], []
355
+ for k, line_dict in enumerate(word_speaker_mapping):
356
+ word, speaker = line_dict["word"], line_dict["speaker"]
357
+ words_list.append(word)
358
+ speaker_list.append(speaker)
359
+
360
+ k = 0
361
+ while k < len(word_speaker_mapping):
362
+ line_dict = word_speaker_mapping[k]
363
+ if (
364
+ k < wsp_len - 1
365
+ and speaker_list[k] != speaker_list[k + 1]
366
+ and not is_word_sentence_end(k)
367
+ ):
368
+ left_idx = get_first_word_idx_of_sentence(
369
+ k, words_list, speaker_list, max_words_in_sentence
370
+ )
371
+ right_idx = (
372
+ get_last_word_idx_of_sentence(
373
+ k, words_list, max_words_in_sentence - k + left_idx - 1
374
+ )
375
+ if left_idx > -1
376
+ else -1
377
+ )
378
+ if min(left_idx, right_idx) == -1:
379
+ k += 1
380
+ continue
381
+
382
+ spk_labels = speaker_list[left_idx : right_idx + 1]
383
+ mod_speaker = max(set(spk_labels), key=spk_labels.count)
384
+ if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
385
+ k += 1
386
+ continue
387
+
388
+ speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
389
+ right_idx - left_idx + 1
390
+ )
391
+ k = right_idx
392
+
393
+ k += 1
394
+
395
+ k, realigned_list = 0, []
396
+ while k < len(word_speaker_mapping):
397
+ line_dict = word_speaker_mapping[k].copy()
398
+ line_dict["speaker"] = speaker_list[k]
399
+ realigned_list.append(line_dict)
400
+ k += 1
401
+
402
+ return realigned_list
403
+
404
+
405
+ def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
406
+ sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
407
+ s, e, spk = spk_ts[0]
408
+ prev_spk = spk
409
+
410
+ snts = []
411
+ snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}
412
+
413
+ for wrd_dict in word_speaker_mapping:
414
+ wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
415
+ s, e = wrd_dict["start_time"], wrd_dict["end_time"]
416
+ if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
417
+ snts.append(snt)
418
+ snt = {
419
+ "speaker": f"Speaker {spk}",
420
+ "start_time": s,
421
+ "end_time": e,
422
+ "text": "",
423
+ }
424
+ else:
425
+ snt["end_time"] = e
426
+ snt["text"] += wrd + " "
427
+ prev_spk = spk
428
+
429
+ snts.append(snt)
430
+ return snts
431
+
432
+
433
+ def get_speaker_aware_transcript(sentences_speaker_mapping, f):
434
+ previous_speaker = sentences_speaker_mapping[0]["speaker"]
435
+ f.write(f"{previous_speaker}: ")
436
+
437
+ for sentence_dict in sentences_speaker_mapping:
438
+ speaker = sentence_dict["speaker"]
439
+ sentence = sentence_dict["text"]
440
+
441
+ # If this speaker doesn't match the previous one, start a new paragraph
442
+ if speaker != previous_speaker:
443
+ f.write(f"\n\n{speaker}: ")
444
+ previous_speaker = speaker
445
+
446
+ # No matter what, write the current sentence
447
+ f.write(sentence + " ")
448
+
449
+
450
+ def format_timestamp(
451
+ milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
452
+ ):
453
+ assert milliseconds >= 0, "non-negative timestamp expected"
454
+
455
+ hours = milliseconds // 3_600_000
456
+ milliseconds -= hours * 3_600_000
457
+
458
+ minutes = milliseconds // 60_000
459
+ milliseconds -= minutes * 60_000
460
+
461
+ seconds = milliseconds // 1_000
462
+ milliseconds -= seconds * 1_000
463
+
464
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
465
+ return (
466
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
467
+ )
468
+
469
+
470
+ def write_srt(transcript, file):
471
+ """
472
+ Write a transcript to a file in SRT format.
473
+
474
+ """
475
+ for i, segment in enumerate(transcript, start=1):
476
+ # write srt lines
477
+ print(
478
+ f"{i}\n"
479
+ f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
480
+ f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
481
+ f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
482
+ file=file,
483
+ flush=True,
484
+ )
485
+
486
+
487
+ def find_numeral_symbol_tokens(tokenizer):
488
+ numeral_symbol_tokens = [
489
+ -1,
490
+ ]
491
+ for token, token_id in tokenizer.get_vocab().items():
492
+ has_numeral_symbol = any(c in "0123456789%$£" for c in token)
493
+ if has_numeral_symbol:
494
+ numeral_symbol_tokens.append(token_id)
495
+ return numeral_symbol_tokens
496
+
497
+
498
+ def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
499
+ # if current word is the last word
500
+ if current_word_index == len(word_timestamps) - 1:
501
+ return word_timestamps[current_word_index]["start"]
502
+
503
+ next_word_index = current_word_index + 1
504
+ while current_word_index < len(word_timestamps) - 1:
505
+ if word_timestamps[next_word_index].get("start") is None:
506
+ # if next word doesn't have a start timestamp
507
+ # merge it with the current word and delete it
508
+ word_timestamps[current_word_index]["word"] += (
509
+ " " + word_timestamps[next_word_index]["word"]
510
+ )
511
+
512
+ word_timestamps[next_word_index]["word"] = None
513
+ next_word_index += 1
514
+ if next_word_index == len(word_timestamps):
515
+ return final_timestamp
516
+
517
+ else:
518
+ return word_timestamps[next_word_index]["start"]
519
+
520
+
521
+ def filter_missing_timestamps(
522
+ word_timestamps, initial_timestamp=0, final_timestamp=None
523
+ ):
524
+ # handle the first and last word
525
+ if word_timestamps[0].get("start") is None:
526
+ word_timestamps[0]["start"] = (
527
+ initial_timestamp if initial_timestamp is not None else 0
528
+ )
529
+ word_timestamps[0]["end"] = _get_next_start_timestamp(
530
+ word_timestamps, 0, final_timestamp
531
+ )
532
+
533
+ result = [
534
+ word_timestamps[0],
535
+ ]
536
+
537
+ for i, ws in enumerate(word_timestamps[1:], start=1):
538
+ # if ws doesn't have a start and end
539
+ # use the previous end as start and next start as end
540
+ if ws.get("start") is None and ws.get("word") is not None:
541
+ ws["start"] = word_timestamps[i - 1]["end"]
542
+ ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)
543
+
544
+ if ws["word"] is not None:
545
+ result.append(ws)
546
+ return result
547
+
548
+
549
+ def cleanup(path: str):
550
+ """path could either be relative or absolute."""
551
+ # check if file or directory exists
552
+ if os.path.isfile(path) or os.path.islink(path):
553
+ # remove file
554
+ os.remove(path)
555
+ elif os.path.isdir(path):
556
+ # remove directory and all its content
557
+ shutil.rmtree(path)
558
+ else:
559
+ raise ValueError("Path {} is not a file or dir.".format(path))
560
+
561
+
562
+ def process_language_arg(language: str, model_name: str):
563
+ """
564
+ Process the language argument to make sure it's valid and convert language names to language codes.
565
+ """
566
+ if language is not None:
567
+ language = language.lower()
568
+ if language not in LANGUAGES:
569
+ if language in TO_LANGUAGE_CODE:
570
+ language = TO_LANGUAGE_CODE[language]
571
+ else:
572
+ raise ValueError(f"Unsupported language: {language}")
573
+
574
+ if model_name.endswith(".en") and language != "en":
575
+ if language is not None:
576
+ logging.warning(
577
+ f"{model_name} is an English-only model but received '{language}'; using English instead."
578
+ )
579
+ language = "en"
580
+ return language
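A quick sanity check of the timestamp and SRT helpers defined above (a minimal sketch; the sentence dict is invented, and importing helpers pulls in whisperx, nltk, wget and omegaconf, so the requirements must already be installed):

import sys
from helpers import format_timestamp, write_srt

print(format_timestamp(3_723_456, always_include_hours=True, decimal_marker=","))
# 01:02:03,456

sentence = {"speaker": "Speaker 0", "start_time": 840, "end_time": 4060, "text": "Hello there. "}
write_srt([sentence], sys.stdout)
# 1
# 00:00:00,840 --> 00:00:04,060
# Speaker 0: Hello there.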
whisper_diarization_main/nemo_msdd_configs/diar_infer_general.yaml ADDED
@@ -0,0 +1,95 @@
1
+ # This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
2
+ # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
3
+ # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
4
+ # The configurations in this YAML file are optimized to show balanced performance across various domain types. VAD is optimized on multilingual ASR datasets and the diarizer is optimized on the DIHARD3 development set.
5
+ # An example line in an input manifest file (`.json` format):
6
+ # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
7
+ name: &name "ClusterDiarizer"
8
+
9
+ num_workers: 1
10
+ sample_rate: 16000
11
+ batch_size: 64
12
+ device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
13
+ verbose: True # enable additional logging
14
+
15
+ diarizer:
16
+ manifest_filepath: ???
17
+ out_dir: ???
18
+ oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
19
+ collar: 0.25 # Collar value for scoring
20
+ ignore_overlap: True # Consider or ignore overlap segments while scoring
21
+
22
+ vad:
23
+ model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
24
+ external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
25
+
26
+ parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets
27
+ window_length_in_sec: 0.63 # Window length in sec for VAD context input
28
+ shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction
29
+ smoothing: False # False or type of smoothing method (eg: median)
30
+ overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
31
+ onset: 0.5 # Onset threshold for detecting the beginning and end of a speech
32
+ offset: 0.3 # Offset threshold for detecting the end of a speech
33
+ pad_onset: 0.2 # Adding durations before each speech segment
34
+ pad_offset: 0.2 # Adding durations after each speech segment
35
+ min_duration_on: 0.5 # Threshold for small non_speech deletion
36
+ min_duration_off: 0.5 # Threshold for short speech segment deletion
37
+ filter_speech_first: True
38
+
39
+ speaker_embeddings:
40
+ model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
41
+ parameters:
42
+ window_length_in_sec: [1.9,1.2,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
43
+ shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
44
+ multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
45
+ save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.
46
+
47
+ clustering:
48
+ parameters:
49
+ oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
50
+ max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
51
+ enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
52
+ max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
53
+ sparse_search_volume: 10 # The higher the number, the more values will be examined with more time.
54
+ maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
55
+ chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
56
+ embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)
57
+
58
+
59
+ msdd_model:
60
+ model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
61
+ parameters:
62
+ use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
63
+ infer_batch_size: 25 # Batch size for MSDD inference.
64
+ sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
65
+ seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
66
+ split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
67
+ diar_window_length: 50 # The length of split short sequence when split_infer is True.
68
+ overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.
69
+
70
+ asr:
71
+ model_path: null # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
72
+ parameters:
73
+ asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
74
+ asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
75
+ asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
76
+ decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
77
+ word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
78
+ word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
79
+ fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
80
+ colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
81
+ print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
82
+ break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)
83
+
84
+ ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
85
+ pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
86
+ beam_width: 32
87
+ alpha: 0.5
88
+ beta: 2.5
89
+
90
+ realigning_lm_parameters: # Experimental feature
91
+ arpa_language_model: null # Provide a KenLM language model in .arpa format.
92
+ min_number_of_words: 3 # Min number of words for the left context.
93
+ max_number_of_words: 10 # Max number of words for the right context.
94
+ logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
95
+
whisper_diarization_main/nemo_msdd_configs/diar_infer_meeting.yaml ADDED
@@ -0,0 +1,94 @@
1
+ # This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
2
+ # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
3
+ # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
4
+ # The configurations in this YAML file are suitable for 3~5 speakers participating in a meeting and may not show the best performance on other types of dialogues.
5
+ # An example line in an input manifest file (`.json` format):
6
+ # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
7
+ name: &name "ClusterDiarizer"
8
+
9
+ num_workers: 1
10
+ sample_rate: 16000
11
+ batch_size: 64
12
+ device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
13
+ verbose: True # enable additional logging
14
+
15
+ diarizer:
16
+ manifest_filepath: ???
17
+ out_dir: ???
18
+ oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
19
+ collar: 0.25 # Collar value for scoring
20
+ ignore_overlap: True # Consider or ignore overlap segments while scoring
21
+
22
+ vad:
23
+ model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
24
+ external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
25
+
26
+ parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
27
+ window_length_in_sec: 0.63 # Window length in sec for VAD context input
28
+ shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction
29
+ smoothing: False # False or type of smoothing method (eg: median)
30
+ overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
31
+ onset: 0.9 # Onset threshold for detecting the beginning and end of a speech
32
+ offset: 0.5 # Offset threshold for detecting the end of a speech
33
+ pad_onset: 0 # Adding durations before each speech segment
34
+ pad_offset: 0 # Adding durations after each speech segment
35
+ min_duration_on: 0 # Threshold for small non_speech deletion
36
+ min_duration_off: 0.6 # Threshold for short speech segment deletion
37
+ filter_speech_first: True
38
+
39
+ speaker_embeddings:
40
+ model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
41
+ parameters:
42
+ window_length_in_sec: [3.0,2.5,2.0,1.5,1.0,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
43
+ shift_length_in_sec: [1.5,1.25,1.0,0.75,0.5,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
44
+ multiscale_weights: [1,1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
45
+ save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.
46
+
47
+ clustering:
48
+ parameters:
49
+ oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
50
+ max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
51
+ enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
52
+ max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
53
+ sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
54
+ maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
55
+ chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
56
+ embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)
57
+
58
+ msdd_model:
59
+ model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
60
+ parameters:
61
+ use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
62
+ infer_batch_size: 25 # Batch size for MSDD inference.
63
+ sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
64
+ seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
65
+ split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
66
+ diar_window_length: 50 # The length of split short sequence when split_infer is True.
67
+ overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.
68
+
69
+ asr:
70
+ model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
71
+ parameters:
72
+ asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
73
+ asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
74
+ asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
75
+ decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
76
+ word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
77
+ word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
78
+ fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
79
+ colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
80
+ print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
81
+ break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)
82
+
83
+ ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
84
+ pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
85
+ beam_width: 32
86
+ alpha: 0.5
87
+ beta: 2.5
88
+
89
+ realigning_lm_parameters: # Experimental feature
90
+ arpa_language_model: null # Provide a KenLM language model in .arpa format.
91
+ min_number_of_words: 3 # Min number of words for the left context.
92
+ max_number_of_words: 10 # Max number of words for the right context.
93
+ logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
94
+
whisper_diarization_main/nemo_msdd_configs/diar_infer_telephonic.yaml ADDED
@@ -0,0 +1,94 @@
1
+ # This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
2
+ # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
3
+ # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
4
+ # The configurations in this YAML file are suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on other types of acoustic conditions or dialogues.
5
+ # An example line in an input manifest file (`.json` format):
6
+ # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
7
+ name: &name "ClusterDiarizer"
8
+
9
+ num_workers: 1
10
+ sample_rate: 16000
11
+ batch_size: 64
12
+ device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
13
+ verbose: True # enable additional logging
14
+
15
+ diarizer:
16
+ manifest_filepath: ???
17
+ out_dir: ???
18
+ oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
19
+ collar: 0.25 # Collar value for scoring
20
+ ignore_overlap: True # Consider or ignore overlap segments while scoring
21
+
22
+ vad:
23
+ model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
24
+ external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
25
+
26
+ parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
27
+ window_length_in_sec: 0.15 # Window length in sec for VAD context input
28
+ shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction
29
+ smoothing: "median" # False or type of smoothing method (eg: median)
30
+ overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
31
+ onset: 0.1 # Onset threshold for detecting the beginning and end of a speech
32
+ offset: 0.1 # Offset threshold for detecting the end of a speech
33
+ pad_onset: 0.1 # Adding durations before each speech segment
34
+ pad_offset: 0 # Adding durations after each speech segment
35
+ min_duration_on: 0 # Threshold for small non_speech deletion
36
+ min_duration_off: 0.2 # Threshold for short speech segment deletion
37
+ filter_speech_first: True
38
+
39
+ speaker_embeddings:
40
+ model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
41
+ parameters:
42
+ window_length_in_sec: [1.5,1.25,1.0,0.75,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
43
+ shift_length_in_sec: [0.75,0.625,0.5,0.375,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
44
+ multiscale_weights: [1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
45
+ save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.
46
+
47
+ clustering:
48
+ parameters:
49
+ oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
50
+ max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
51
+ enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
52
+ max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
53
+ sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
54
+ maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
55
+ chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
56
+ embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)
57
+
58
+ msdd_model:
59
+ model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
60
+ parameters:
61
+ use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
62
+ infer_batch_size: 25 # Batch size for MSDD inference.
63
+ sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
64
+ seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
65
+ split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
66
+ diar_window_length: 50 # The length of split short sequence when split_infer is True.
67
+ overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.
68
+
69
+ asr:
70
+ model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
71
+ parameters:
72
+ asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
73
+ asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
74
+ asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
75
+ decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
76
+ word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
77
+ word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
78
+ fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
79
+ colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
80
+ print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
81
+ break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)
82
+
83
+ ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
84
+ pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
85
+ beam_width: 32
86
+ alpha: 0.5
87
+ beta: 2.5
88
+
89
+ realigning_lm_parameters: # Experimental feature
90
+ arpa_language_model: null # Provide a KenLM language model in .arpa format.
91
+ min_number_of_words: 3 # Min number of words for the left context.
92
+ max_number_of_words: 10 # Max number of words for the right context.
93
+ logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
94
+
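The three configs above share the same structure and differ mainly in VAD and multiscale-embedding tuning; create_config in helpers.py is hard-coded to the telephonic profile. A minimal sketch of loading one of these files directly with OmegaConf and overriding a value (the override is illustrative only, and the path assumes the repository layout shown here):

from omegaconf import OmegaConf

cfg = OmegaConf.load("nemo_msdd_configs/diar_infer_general.yaml")
cfg.diarizer.clustering.parameters.max_num_speakers = 4  # illustrative override
print(cfg.diarizer.vad.parameters.onset)                 # 0.5 in the general profile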
whisper_diarization_main/nemo_process.py ADDED
@@ -0,0 +1,31 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from nemo.collections.asr.models.msdd_models import NeuralDiarizer
6
+ from pydub import AudioSegment
7
+
8
+ from helpers import create_config
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument(
12
+ "-a", "--audio", help="name of the target audio file", required=True
13
+ )
14
+ parser.add_argument(
15
+ "--device",
16
+ dest="device",
17
+ default="cuda" if torch.cuda.is_available() else "cpu",
18
+ help="if you have a GPU use 'cuda', otherwise 'cpu'",
19
+ )
20
+ args = parser.parse_args()
21
+
22
+ # convert audio to mono for NeMo compatibility
23
+ sound = AudioSegment.from_file(args.audio).set_channels(1)
24
+ ROOT = os.getcwd()
25
+ temp_path = os.path.join(ROOT, "temp_outputs")
26
+ os.makedirs(temp_path, exist_ok=True)
27
+ sound.export(os.path.join(temp_path, "mono_file.wav"), format="wav")
28
+
29
+ # Initialize NeMo MSDD diarization model
30
+ msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
31
+ msdd_model.diarize()
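The script above is what the main pipeline launches in a subprocess; it can also be run on its own. A small sketch mirroring that invocation ("vocals.wav" is a placeholder path, and the output locations follow from create_config and the main script):

import subprocess

subprocess.run(
    ["python3", "nemo_process.py", "-a", "vocals.wav", "--device", "cpu"],
    check=True,
)
# expected artifacts, based on the code above:
#   temp_outputs/mono_file.wav              (mono audio fed to NeMo)
#   temp_outputs/pred_rttms/mono_file.rttm  (speaker turns read back by the main script)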
whisper_diarization_main/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ wget
2
+ nemo_toolkit[asr]==1.23.0
3
+ git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560
4
+ git+https://github.com/adefossez/demucs.git
5
+ git+https://github.com/oliverguhr/deepmultilingualpunctuation.git
6
+ git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
7
+
whisper_diarization_main/transcription_helpers.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+
3
+
4
+ def transcribe(
5
+ audio_file: str,
6
+ language: str,
7
+ model_name: str,
8
+ compute_dtype: str,
9
+ suppress_numerals: bool,
10
+ device: str,
11
+ ):
12
+ from faster_whisper import WhisperModel
13
+
14
+ from helpers import find_numeral_symbol_tokens, wav2vec2_langs
15
+
16
+ # Faster Whisper non-batched
17
+ # Run on GPU with FP16
18
+ whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)
19
+
20
+ # or run on GPU with INT8
21
+ # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
22
+ # or run on CPU with INT8
23
+ # model = WhisperModel(model_size, device="cpu", compute_type="int8")
24
+
25
+ if suppress_numerals:
26
+ numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
27
+ else:
28
+ numeral_symbol_tokens = None
29
+
30
+ if language is not None and language in wav2vec2_langs:
31
+ word_timestamps = False
32
+ else:
33
+ word_timestamps = True
34
+
35
+ segments, info = whisper_model.transcribe(
36
+ audio_file,
37
+ language=language,
38
+ beam_size=5,
39
+ word_timestamps=word_timestamps, # TODO: disable this if the language is supported by wav2vec2
40
+ suppress_tokens=numeral_symbol_tokens,
41
+ vad_filter=True,
42
+ )
43
+ whisper_results = []
44
+ for segment in segments:
45
+ whisper_results.append(segment._asdict())
46
+ # clear gpu vram
47
+ del whisper_model
48
+ torch.cuda.empty_cache()
49
+ return whisper_results, info.language
50
+
51
+
52
+ def transcribe_batched(
53
+ audio_file: str,
54
+ language: str,
55
+ batch_size: int,
56
+ model_name: str,
57
+ compute_dtype: str,
58
+ suppress_numerals: bool,
59
+ device: str,
60
+ ):
61
+ import whisperx
62
+
63
+ # Faster Whisper batched
64
+ whisper_model = whisperx.load_model(
65
+ model_name,
66
+ device,
67
+ compute_type=compute_dtype,
68
+ asr_options={"suppress_numerals": suppress_numerals},
69
+ )
70
+ audio = whisperx.load_audio(audio_file)
71
+ result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)
72
+ del whisper_model
73
+ torch.cuda.empty_cache()
74
+ return result["segments"], result["language"], audio