Upload 13 files
- whisper_diarization_main/.gitignore +7 -0
- whisper_diarization_main/LICENSE +24 -0
- whisper_diarization_main/README.md +93 -0
- whisper_diarization_main/Whisper_Transcription_+_NeMo_Diarization.ipynb +1000 -0
- whisper_diarization_main/diarize.py +235 -0
- whisper_diarization_main/diarize_parallel.py +217 -0
- whisper_diarization_main/helpers.py +580 -0
- whisper_diarization_main/nemo_msdd_configs/diar_infer_general.yaml +95 -0
- whisper_diarization_main/nemo_msdd_configs/diar_infer_meeting.yaml +94 -0
- whisper_diarization_main/nemo_msdd_configs/diar_infer_telephonic.yaml +94 -0
- whisper_diarization_main/nemo_process.py +31 -0
- whisper_diarization_main/requirements.txt +7 -0
- whisper_diarization_main/transcription_helpers.py +74 -0
whisper_diarization_main/.gitignore
ADDED
@@ -0,0 +1,7 @@
__pycache__/
*.srt
*.txt
*.wav
*.mp3
*.m4a

whisper_diarization_main/LICENSE
ADDED
@@ -0,0 +1,24 @@
BSD 2-Clause License

Copyright (c) 2023, Mahmoud Ashraf

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
whisper_diarization_main/README.md
ADDED
@@ -0,0 +1,93 @@
<h1 align="center">Speaker Diarization Using OpenAI Whisper</h1>
<p align="center">
  <a href="https://github.com/MahmoudAshraf97/whisper-diarization/stargazers">
    <img src="https://img.shields.io/github/stars/MahmoudAshraf97/whisper-diarization.svg?colorA=orange&colorB=orange&logo=github"
         alt="GitHub stars">
  </a>
  <a href="https://github.com/MahmoudAshraf97/whisper-diarization/issues">
    <img src="https://img.shields.io/github/issues/MahmoudAshraf97/whisper-diarization.svg"
         alt="GitHub issues">
  </a>
  <a href="https://github.com/MahmoudAshraf97/whisper-diarization/blob/master/LICENSE">
    <img src="https://img.shields.io/github/license/MahmoudAshraf97/whisper-diarization.svg"
         alt="GitHub license">
  </a>
  <a href="https://twitter.com/intent/tweet?text=&url=https%3A%2F%2Fgithub.com%2FMahmoudAshraf97%2Fwhisper-diarization">
    <img src="https://img.shields.io/twitter/url/https/github.com/MahmoudAshraf97/whisper-diarization.svg?style=social" alt="Twitter">
  </a>
  <a href="https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab">
  </a>
</p>

#
Speaker Diarization pipeline based on OpenAI Whisper.
I'd like to thank [@m-bain](https://github.com/m-bain) for the batched Whisper inference and [@mu4farooqi](https://github.com/mu4farooqi) for the punctuation realignment algorithm.

<img src="https://github.blog/wp-content/uploads/2020/09/github-stars-logo_Color.png" alt="drawing" width="25"/> **Please star the project on GitHub (see the top-right corner) if you appreciate my contribution to the community!**

## What is it
This repository combines Whisper's ASR capabilities with Voice Activity Detection (VAD) and speaker embeddings to identify the speaker of each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to increase speaker embedding accuracy, and the transcription is generated using Whisper; the timestamps are then corrected and aligned using WhisperX to minimize diarization error due to time shift. Next, the audio is passed into MarbleNet for VAD and segmentation to exclude silences, and TitaNet is used to extract speaker embeddings and identify the speaker of each segment. Finally, the result is associated with the word timestamps generated by WhisperX to detect the speaker of each word, and realigned using punctuation models to compensate for minor time shifts.

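To make the word-to-speaker association step concrete, the snippet below is a simplified, self-contained illustration of the idea: each word's anchor timestamp is matched against the diarizer's speaker turns. This is only a sketch with made-up toy data; the actual implementation lives in `helpers.py` (`get_words_speaker_mapping`) and handles more edge cases.

```
# Simplified illustration only -- the full version is get_words_speaker_mapping in helpers.py.
def assign_speakers(word_timestamps, speaker_turns):
    # word_timestamps: [{"word", "start", "end"}, ...] in milliseconds
    # speaker_turns: [(start_ms, end_ms, speaker_id), ...] sorted by start time
    mapping, turn_idx = [], 0
    for word in word_timestamps:
        anchor = word["start"]  # anchor each word at its start time
        # advance until the current turn's end covers the anchor
        while turn_idx < len(speaker_turns) - 1 and anchor > speaker_turns[turn_idx][1]:
            turn_idx += 1
        mapping.append({**word, "speaker": speaker_turns[turn_idx][2]})
    return mapping

words = [{"word": "hello", "start": 120, "end": 400}, {"word": "there", "start": 620, "end": 900}]
turns = [(0, 500, 0), (500, 1500, 1)]
print(assign_speakers(words, turns))
# -> speaker 0 for "hello", speaker 1 for "there"
```

The punctuation-based realignment step then revisits this mapping so that a sentence split across a speaker change is attributed to the dominant speaker.
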
WhisperX and NeMo parameters are coded into `diarize.py` and `helpers.py`; I will add CLI arguments to change them later.
## Installation
`PyTorch`, `FFMPEG` and `Cython` are needed as prerequisites to install the requirements.
```
pip install cython torch
```
or
```
pip install torch
sudo apt update && sudo apt install cython3
```
```
# on Ubuntu or Debian
sudo apt update && sudo apt install ffmpeg

# on Arch Linux
sudo pacman -S ffmpeg

# on MacOS using Homebrew (https://brew.sh/)
brew install ffmpeg

# on Windows using Chocolatey (https://chocolatey.org/)
choco install ffmpeg

# on Windows using Scoop (https://scoop.sh/)
scoop install ffmpeg
```
```
pip install -r requirements.txt
```
## Usage

```
python diarize.py -a AUDIO_FILE_NAME
```

If your system has enough VRAM (>=10GB), you can use `diarize_parallel.py` instead. The difference is that it runs NeMo in parallel with Whisper; this can be beneficial in some cases, and the result is the same since the two models are independent of each other. This is still experimental, so expect errors and sharp edges. Your feedback is welcome.

## Command Line Options

- `-a AUDIO_FILE_NAME`: The name of the audio file to be processed
- `--no-stem`: Disables source separation
- `--whisper-model`: The model to be used for ASR, default is `medium.en`
- `--suppress_numerals`: Transcribes numbers as their spoken words instead of digits, which improves alignment accuracy
- `--device`: Choose which device to use, defaults to `cuda` if available
- `--language`: Manually select the language, useful if language detection fails
- `--batch-size`: Batch size for batched inference; reduce if you run out of memory, or set to 0 for non-batched inference

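For reference, a typical invocation combining the options above might look like the following; the file name and option values are placeholders, not defaults:

```
python diarize.py -a meeting.wav --whisper-model large-v2 --language en --batch-size 4 --no-stem

# with enough VRAM, the parallel variant is invoked the same way
python diarize_parallel.py -a meeting.wav
```
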
## Known Limitations
- Overlapping speakers are not yet addressed. A possible approach would be to separate the audio file and isolate only one speaker, then feed it into the pipeline, but this would need much more computation.
- There might be some errors; please raise an issue if you encounter any.

## Future Improvements
- Implement a maximum length per sentence for SRT

## Acknowledgements
Special thanks to [@adamjonas](https://github.com/adamjonas) for supporting this project.
This work is based on [OpenAI's Whisper](https://github.com/openai/whisper), [Faster Whisper](https://github.com/guillaumekln/faster-whisper), [Nvidia NeMo](https://github.com/NVIDIA/NeMo), and [Facebook's Demucs](https://github.com/facebookresearch/demucs).
whisper_diarization_main/Whisper_Transcription_+_NeMo_Diarization.ipynb
ADDED
@@ -0,0 +1,1000 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"attachments": {},
|
5 |
+
"cell_type": "markdown",
|
6 |
+
"metadata": {
|
7 |
+
"colab_type": "text",
|
8 |
+
"id": "view-in-github"
|
9 |
+
},
|
10 |
+
"source": [
|
11 |
+
"<a href=\"https://colab.research.google.com/github/MahmoudAshraf97/whisper-diarization/blob/main/Whisper_Transcription_%2B_NeMo_Diarization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"attachments": {},
|
16 |
+
"cell_type": "markdown",
|
17 |
+
"metadata": {
|
18 |
+
"id": "eCmjcOc9yEtQ"
|
19 |
+
},
|
20 |
+
"source": [
|
21 |
+
"# Installing Dependencies"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": null,
|
27 |
+
"metadata": {
|
28 |
+
"id": "Tn1c-CoDv2kw"
|
29 |
+
},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"!pip install git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n",
|
33 |
+
"!pip install --no-build-isolation nemo_toolkit[asr]==1.23.0\n",
|
34 |
+
"!pip install --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n",
|
35 |
+
"!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n",
|
36 |
+
"!pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": null,
|
42 |
+
"metadata": {
|
43 |
+
"id": "YzhncHP0ytbQ"
|
44 |
+
},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import os\n",
|
48 |
+
"import wget\n",
|
49 |
+
"from omegaconf import OmegaConf\n",
|
50 |
+
"import json\n",
|
51 |
+
"import shutil\n",
|
52 |
+
"import torch\n",
|
53 |
+
"import torchaudio\n",
|
54 |
+
"from nemo.collections.asr.models.msdd_models import NeuralDiarizer\n",
|
55 |
+
"from deepmultilingualpunctuation import PunctuationModel\n",
|
56 |
+
"import re\n",
|
57 |
+
"import logging\n",
|
58 |
+
"import nltk\n",
|
59 |
+
"from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE\n",
|
60 |
+
"from ctc_forced_aligner import (\n",
|
61 |
+
" load_alignment_model,\n",
|
62 |
+
" generate_emissions,\n",
|
63 |
+
" preprocess_text,\n",
|
64 |
+
" get_alignments,\n",
|
65 |
+
" get_spans,\n",
|
66 |
+
" postprocess_results,\n",
|
67 |
+
")"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"attachments": {},
|
72 |
+
"cell_type": "markdown",
|
73 |
+
"metadata": {
|
74 |
+
"id": "jbsUt3SwyhjD"
|
75 |
+
},
|
76 |
+
"source": [
|
77 |
+
"# Helper Functions"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": null,
|
83 |
+
"metadata": {
|
84 |
+
"id": "Se6Hc7CZygxu"
|
85 |
+
},
|
86 |
+
"outputs": [],
|
87 |
+
"source": [
|
88 |
+
"punct_model_langs = [\n",
|
89 |
+
" \"en\",\n",
|
90 |
+
" \"fr\",\n",
|
91 |
+
" \"de\",\n",
|
92 |
+
" \"es\",\n",
|
93 |
+
" \"it\",\n",
|
94 |
+
" \"nl\",\n",
|
95 |
+
" \"pt\",\n",
|
96 |
+
" \"bg\",\n",
|
97 |
+
" \"pl\",\n",
|
98 |
+
" \"cs\",\n",
|
99 |
+
" \"sk\",\n",
|
100 |
+
" \"sl\",\n",
|
101 |
+
"]\n",
|
102 |
+
"langs_to_iso = {\n",
|
103 |
+
" \"af\": \"afr\",\n",
|
104 |
+
" \"am\": \"amh\",\n",
|
105 |
+
" \"ar\": \"ara\",\n",
|
106 |
+
" \"as\": \"asm\",\n",
|
107 |
+
" \"az\": \"aze\",\n",
|
108 |
+
" \"ba\": \"bak\",\n",
|
109 |
+
" \"be\": \"bel\",\n",
|
110 |
+
" \"bg\": \"bul\",\n",
|
111 |
+
" \"bn\": \"ben\",\n",
|
112 |
+
" \"bo\": \"tib\",\n",
|
113 |
+
" \"br\": \"bre\",\n",
|
114 |
+
" \"bs\": \"bos\",\n",
|
115 |
+
" \"ca\": \"cat\",\n",
|
116 |
+
" \"cs\": \"cze\",\n",
|
117 |
+
" \"cy\": \"wel\",\n",
|
118 |
+
" \"da\": \"dan\",\n",
|
119 |
+
" \"de\": \"ger\",\n",
|
120 |
+
" \"el\": \"gre\",\n",
|
121 |
+
" \"en\": \"eng\",\n",
|
122 |
+
" \"es\": \"spa\",\n",
|
123 |
+
" \"et\": \"est\",\n",
|
124 |
+
" \"eu\": \"baq\",\n",
|
125 |
+
" \"fa\": \"per\",\n",
|
126 |
+
" \"fi\": \"fin\",\n",
|
127 |
+
" \"fo\": \"fao\",\n",
|
128 |
+
" \"fr\": \"fre\",\n",
|
129 |
+
" \"gl\": \"glg\",\n",
|
130 |
+
" \"gu\": \"guj\",\n",
|
131 |
+
" \"ha\": \"hau\",\n",
|
132 |
+
" \"haw\": \"haw\",\n",
|
133 |
+
" \"he\": \"heb\",\n",
|
134 |
+
" \"hi\": \"hin\",\n",
|
135 |
+
" \"hr\": \"hrv\",\n",
|
136 |
+
" \"ht\": \"hat\",\n",
|
137 |
+
" \"hu\": \"hun\",\n",
|
138 |
+
" \"hy\": \"arm\",\n",
|
139 |
+
" \"id\": \"ind\",\n",
|
140 |
+
" \"is\": \"ice\",\n",
|
141 |
+
" \"it\": \"ita\",\n",
|
142 |
+
" \"ja\": \"jpn\",\n",
|
143 |
+
" \"jw\": \"jav\",\n",
|
144 |
+
" \"ka\": \"geo\",\n",
|
145 |
+
" \"kk\": \"kaz\",\n",
|
146 |
+
" \"km\": \"khm\",\n",
|
147 |
+
" \"kn\": \"kan\",\n",
|
148 |
+
" \"ko\": \"kor\",\n",
|
149 |
+
" \"la\": \"lat\",\n",
|
150 |
+
" \"lb\": \"ltz\",\n",
|
151 |
+
" \"ln\": \"lin\",\n",
|
152 |
+
" \"lo\": \"lao\",\n",
|
153 |
+
" \"lt\": \"lit\",\n",
|
154 |
+
" \"lv\": \"lav\",\n",
|
155 |
+
" \"mg\": \"mlg\",\n",
|
156 |
+
" \"mi\": \"mao\",\n",
|
157 |
+
" \"mk\": \"mac\",\n",
|
158 |
+
" \"ml\": \"mal\",\n",
|
159 |
+
" \"mn\": \"mon\",\n",
|
160 |
+
" \"mr\": \"mar\",\n",
|
161 |
+
" \"ms\": \"may\",\n",
|
162 |
+
" \"mt\": \"mlt\",\n",
|
163 |
+
" \"my\": \"bur\",\n",
|
164 |
+
" \"ne\": \"nep\",\n",
|
165 |
+
" \"nl\": \"dut\",\n",
|
166 |
+
" \"nn\": \"nno\",\n",
|
167 |
+
" \"no\": \"nor\",\n",
|
168 |
+
" \"oc\": \"oci\",\n",
|
169 |
+
" \"pa\": \"pan\",\n",
|
170 |
+
" \"pl\": \"pol\",\n",
|
171 |
+
" \"ps\": \"pus\",\n",
|
172 |
+
" \"pt\": \"por\",\n",
|
173 |
+
" \"ro\": \"rum\",\n",
|
174 |
+
" \"ru\": \"rus\",\n",
|
175 |
+
" \"sa\": \"san\",\n",
|
176 |
+
" \"sd\": \"snd\",\n",
|
177 |
+
" \"si\": \"sin\",\n",
|
178 |
+
" \"sk\": \"slo\",\n",
|
179 |
+
" \"sl\": \"slv\",\n",
|
180 |
+
" \"sn\": \"sna\",\n",
|
181 |
+
" \"so\": \"som\",\n",
|
182 |
+
" \"sq\": \"alb\",\n",
|
183 |
+
" \"sr\": \"srp\",\n",
|
184 |
+
" \"su\": \"sun\",\n",
|
185 |
+
" \"sv\": \"swe\",\n",
|
186 |
+
" \"sw\": \"swa\",\n",
|
187 |
+
" \"ta\": \"tam\",\n",
|
188 |
+
" \"te\": \"tel\",\n",
|
189 |
+
" \"tg\": \"tgk\",\n",
|
190 |
+
" \"th\": \"tha\",\n",
|
191 |
+
" \"tk\": \"tuk\",\n",
|
192 |
+
" \"tl\": \"tgl\",\n",
|
193 |
+
" \"tr\": \"tur\",\n",
|
194 |
+
" \"tt\": \"tat\",\n",
|
195 |
+
" \"uk\": \"ukr\",\n",
|
196 |
+
" \"ur\": \"urd\",\n",
|
197 |
+
" \"uz\": \"uzb\",\n",
|
198 |
+
" \"vi\": \"vie\",\n",
|
199 |
+
" \"yi\": \"yid\",\n",
|
200 |
+
" \"yo\": \"yor\",\n",
|
201 |
+
" \"yue\": \"yue\",\n",
|
202 |
+
" \"zh\": \"chi\",\n",
|
203 |
+
"}\n",
|
204 |
+
"\n",
|
205 |
+
"\n",
|
206 |
+
"whisper_langs = sorted(LANGUAGES.keys()) + sorted(\n",
|
207 |
+
" [k.title() for k in TO_LANGUAGE_CODE.keys()]\n",
|
208 |
+
")\n",
|
209 |
+
"\n",
|
210 |
+
"\n",
|
211 |
+
"def create_config(output_dir):\n",
|
212 |
+
" DOMAIN_TYPE = \"telephonic\" # Can be meeting, telephonic, or general based on domain type of the audio file\n",
|
213 |
+
" CONFIG_FILE_NAME = f\"diar_infer_{DOMAIN_TYPE}.yaml\"\n",
|
214 |
+
" CONFIG_URL = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}\"\n",
|
215 |
+
" MODEL_CONFIG = os.path.join(output_dir, CONFIG_FILE_NAME)\n",
|
216 |
+
" if not os.path.exists(MODEL_CONFIG):\n",
|
217 |
+
" MODEL_CONFIG = wget.download(CONFIG_URL, output_dir)\n",
|
218 |
+
"\n",
|
219 |
+
" config = OmegaConf.load(MODEL_CONFIG)\n",
|
220 |
+
"\n",
|
221 |
+
" data_dir = os.path.join(output_dir, \"data\")\n",
|
222 |
+
" os.makedirs(data_dir, exist_ok=True)\n",
|
223 |
+
"\n",
|
224 |
+
" meta = {\n",
|
225 |
+
" \"audio_filepath\": os.path.join(output_dir, \"mono_file.wav\"),\n",
|
226 |
+
" \"offset\": 0,\n",
|
227 |
+
" \"duration\": None,\n",
|
228 |
+
" \"label\": \"infer\",\n",
|
229 |
+
" \"text\": \"-\",\n",
|
230 |
+
" \"rttm_filepath\": None,\n",
|
231 |
+
" \"uem_filepath\": None,\n",
|
232 |
+
" }\n",
|
233 |
+
" with open(os.path.join(data_dir, \"input_manifest.json\"), \"w\") as fp:\n",
|
234 |
+
" json.dump(meta, fp)\n",
|
235 |
+
" fp.write(\"\\n\")\n",
|
236 |
+
"\n",
|
237 |
+
" pretrained_vad = \"vad_multilingual_marblenet\"\n",
|
238 |
+
" pretrained_speaker_model = \"titanet_large\"\n",
|
239 |
+
" config.num_workers = 0 # Workaround for multiprocessing hanging with ipython issue\n",
|
240 |
+
" config.diarizer.manifest_filepath = os.path.join(data_dir, \"input_manifest.json\")\n",
|
241 |
+
" config.diarizer.out_dir = (\n",
|
242 |
+
" output_dir # Directory to store intermediate files and prediction outputs\n",
|
243 |
+
" )\n",
|
244 |
+
"\n",
|
245 |
+
" config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model\n",
|
246 |
+
" config.diarizer.oracle_vad = (\n",
|
247 |
+
" False # compute VAD provided with model_path to vad config\n",
|
248 |
+
" )\n",
|
249 |
+
" config.diarizer.clustering.parameters.oracle_num_speakers = False\n",
|
250 |
+
"\n",
|
251 |
+
" # Here, we use our in-house pretrained NeMo VAD model\n",
|
252 |
+
" config.diarizer.vad.model_path = pretrained_vad\n",
|
253 |
+
" config.diarizer.vad.parameters.onset = 0.8\n",
|
254 |
+
" config.diarizer.vad.parameters.offset = 0.6\n",
|
255 |
+
" config.diarizer.vad.parameters.pad_offset = -0.05\n",
|
256 |
+
" config.diarizer.msdd_model.model_path = (\n",
|
257 |
+
" \"diar_msdd_telephonic\" # Telephonic speaker diarization model\n",
|
258 |
+
" )\n",
|
259 |
+
"\n",
|
260 |
+
" return config\n",
|
261 |
+
"\n",
|
262 |
+
"\n",
|
263 |
+
"def get_word_ts_anchor(s, e, option=\"start\"):\n",
|
264 |
+
" if option == \"end\":\n",
|
265 |
+
" return e\n",
|
266 |
+
" elif option == \"mid\":\n",
|
267 |
+
" return (s + e) / 2\n",
|
268 |
+
" return s\n",
|
269 |
+
"\n",
|
270 |
+
"\n",
|
271 |
+
"def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option=\"start\"):\n",
|
272 |
+
" s, e, sp = spk_ts[0]\n",
|
273 |
+
" wrd_pos, turn_idx = 0, 0\n",
|
274 |
+
" wrd_spk_mapping = []\n",
|
275 |
+
" for wrd_dict in wrd_ts:\n",
|
276 |
+
" ws, we, wrd = (\n",
|
277 |
+
" int(wrd_dict[\"start\"] * 1000),\n",
|
278 |
+
" int(wrd_dict[\"end\"] * 1000),\n",
|
279 |
+
" wrd_dict[\"text\"],\n",
|
280 |
+
" )\n",
|
281 |
+
" wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)\n",
|
282 |
+
" while wrd_pos > float(e):\n",
|
283 |
+
" turn_idx += 1\n",
|
284 |
+
" turn_idx = min(turn_idx, len(spk_ts) - 1)\n",
|
285 |
+
" s, e, sp = spk_ts[turn_idx]\n",
|
286 |
+
" if turn_idx == len(spk_ts) - 1:\n",
|
287 |
+
" e = get_word_ts_anchor(ws, we, option=\"end\")\n",
|
288 |
+
" wrd_spk_mapping.append(\n",
|
289 |
+
" {\"word\": wrd, \"start_time\": ws, \"end_time\": we, \"speaker\": sp}\n",
|
290 |
+
" )\n",
|
291 |
+
" return wrd_spk_mapping\n",
|
292 |
+
"\n",
|
293 |
+
"\n",
|
294 |
+
"sentence_ending_punctuations = \".?!\"\n",
|
295 |
+
"\n",
|
296 |
+
"\n",
|
297 |
+
"def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):\n",
|
298 |
+
" is_word_sentence_end = (\n",
|
299 |
+
" lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
|
300 |
+
" )\n",
|
301 |
+
" left_idx = word_idx\n",
|
302 |
+
" while (\n",
|
303 |
+
" left_idx > 0\n",
|
304 |
+
" and word_idx - left_idx < max_words\n",
|
305 |
+
" and speaker_list[left_idx - 1] == speaker_list[left_idx]\n",
|
306 |
+
" and not is_word_sentence_end(left_idx - 1)\n",
|
307 |
+
" ):\n",
|
308 |
+
" left_idx -= 1\n",
|
309 |
+
"\n",
|
310 |
+
" return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1\n",
|
311 |
+
"\n",
|
312 |
+
"\n",
|
313 |
+
"def get_last_word_idx_of_sentence(word_idx, word_list, max_words):\n",
|
314 |
+
" is_word_sentence_end = (\n",
|
315 |
+
" lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations\n",
|
316 |
+
" )\n",
|
317 |
+
" right_idx = word_idx\n",
|
318 |
+
" while (\n",
|
319 |
+
" right_idx < len(word_list) - 1\n",
|
320 |
+
" and right_idx - word_idx < max_words\n",
|
321 |
+
" and not is_word_sentence_end(right_idx)\n",
|
322 |
+
" ):\n",
|
323 |
+
" right_idx += 1\n",
|
324 |
+
"\n",
|
325 |
+
" return (\n",
|
326 |
+
" right_idx\n",
|
327 |
+
" if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)\n",
|
328 |
+
" else -1\n",
|
329 |
+
" )\n",
|
330 |
+
"\n",
|
331 |
+
"\n",
|
332 |
+
"def get_realigned_ws_mapping_with_punctuation(\n",
|
333 |
+
" word_speaker_mapping, max_words_in_sentence=50\n",
|
334 |
+
"):\n",
|
335 |
+
" is_word_sentence_end = (\n",
|
336 |
+
" lambda x: x >= 0\n",
|
337 |
+
" and word_speaker_mapping[x][\"word\"][-1] in sentence_ending_punctuations\n",
|
338 |
+
" )\n",
|
339 |
+
" wsp_len = len(word_speaker_mapping)\n",
|
340 |
+
"\n",
|
341 |
+
" words_list, speaker_list = [], []\n",
|
342 |
+
" for k, line_dict in enumerate(word_speaker_mapping):\n",
|
343 |
+
" word, speaker = line_dict[\"word\"], line_dict[\"speaker\"]\n",
|
344 |
+
" words_list.append(word)\n",
|
345 |
+
" speaker_list.append(speaker)\n",
|
346 |
+
"\n",
|
347 |
+
" k = 0\n",
|
348 |
+
" while k < len(word_speaker_mapping):\n",
|
349 |
+
" line_dict = word_speaker_mapping[k]\n",
|
350 |
+
" if (\n",
|
351 |
+
" k < wsp_len - 1\n",
|
352 |
+
" and speaker_list[k] != speaker_list[k + 1]\n",
|
353 |
+
" and not is_word_sentence_end(k)\n",
|
354 |
+
" ):\n",
|
355 |
+
" left_idx = get_first_word_idx_of_sentence(\n",
|
356 |
+
" k, words_list, speaker_list, max_words_in_sentence\n",
|
357 |
+
" )\n",
|
358 |
+
" right_idx = (\n",
|
359 |
+
" get_last_word_idx_of_sentence(\n",
|
360 |
+
" k, words_list, max_words_in_sentence - k + left_idx - 1\n",
|
361 |
+
" )\n",
|
362 |
+
" if left_idx > -1\n",
|
363 |
+
" else -1\n",
|
364 |
+
" )\n",
|
365 |
+
" if min(left_idx, right_idx) == -1:\n",
|
366 |
+
" k += 1\n",
|
367 |
+
" continue\n",
|
368 |
+
"\n",
|
369 |
+
" spk_labels = speaker_list[left_idx : right_idx + 1]\n",
|
370 |
+
" mod_speaker = max(set(spk_labels), key=spk_labels.count)\n",
|
371 |
+
" if spk_labels.count(mod_speaker) < len(spk_labels) // 2:\n",
|
372 |
+
" k += 1\n",
|
373 |
+
" continue\n",
|
374 |
+
"\n",
|
375 |
+
" speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (\n",
|
376 |
+
" right_idx - left_idx + 1\n",
|
377 |
+
" )\n",
|
378 |
+
" k = right_idx\n",
|
379 |
+
"\n",
|
380 |
+
" k += 1\n",
|
381 |
+
"\n",
|
382 |
+
" k, realigned_list = 0, []\n",
|
383 |
+
" while k < len(word_speaker_mapping):\n",
|
384 |
+
" line_dict = word_speaker_mapping[k].copy()\n",
|
385 |
+
" line_dict[\"speaker\"] = speaker_list[k]\n",
|
386 |
+
" realigned_list.append(line_dict)\n",
|
387 |
+
" k += 1\n",
|
388 |
+
"\n",
|
389 |
+
" return realigned_list\n",
|
390 |
+
"\n",
|
391 |
+
"\n",
|
392 |
+
"def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):\n",
|
393 |
+
" sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak\n",
|
394 |
+
" s, e, spk = spk_ts[0]\n",
|
395 |
+
" prev_spk = spk\n",
|
396 |
+
"\n",
|
397 |
+
" snts = []\n",
|
398 |
+
" snt = {\"speaker\": f\"Speaker {spk}\", \"start_time\": s, \"end_time\": e, \"text\": \"\"}\n",
|
399 |
+
"\n",
|
400 |
+
" for wrd_dict in word_speaker_mapping:\n",
|
401 |
+
" wrd, spk = wrd_dict[\"word\"], wrd_dict[\"speaker\"]\n",
|
402 |
+
" s, e = wrd_dict[\"start_time\"], wrd_dict[\"end_time\"]\n",
|
403 |
+
" if spk != prev_spk or sentence_checker(snt[\"text\"] + \" \" + wrd):\n",
|
404 |
+
" snts.append(snt)\n",
|
405 |
+
" snt = {\n",
|
406 |
+
" \"speaker\": f\"Speaker {spk}\",\n",
|
407 |
+
" \"start_time\": s,\n",
|
408 |
+
" \"end_time\": e,\n",
|
409 |
+
" \"text\": \"\",\n",
|
410 |
+
" }\n",
|
411 |
+
" else:\n",
|
412 |
+
" snt[\"end_time\"] = e\n",
|
413 |
+
" snt[\"text\"] += wrd + \" \"\n",
|
414 |
+
" prev_spk = spk\n",
|
415 |
+
"\n",
|
416 |
+
" snts.append(snt)\n",
|
417 |
+
" return snts\n",
|
418 |
+
"\n",
|
419 |
+
"\n",
|
420 |
+
"def get_speaker_aware_transcript(sentences_speaker_mapping, f):\n",
|
421 |
+
" previous_speaker = sentences_speaker_mapping[0][\"speaker\"]\n",
|
422 |
+
" f.write(f\"{previous_speaker}: \")\n",
|
423 |
+
"\n",
|
424 |
+
" for sentence_dict in sentences_speaker_mapping:\n",
|
425 |
+
" speaker = sentence_dict[\"speaker\"]\n",
|
426 |
+
" sentence = sentence_dict[\"text\"]\n",
|
427 |
+
"\n",
|
428 |
+
" # If this speaker doesn't match the previous one, start a new paragraph\n",
|
429 |
+
" if speaker != previous_speaker:\n",
|
430 |
+
" f.write(f\"\\n\\n{speaker}: \")\n",
|
431 |
+
" previous_speaker = speaker\n",
|
432 |
+
"\n",
|
433 |
+
" # No matter what, write the current sentence\n",
|
434 |
+
" f.write(sentence + \" \")\n",
|
435 |
+
"\n",
|
436 |
+
"\n",
|
437 |
+
"def format_timestamp(\n",
|
438 |
+
" milliseconds: float, always_include_hours: bool = False, decimal_marker: str = \".\"\n",
|
439 |
+
"):\n",
|
440 |
+
" assert milliseconds >= 0, \"non-negative timestamp expected\"\n",
|
441 |
+
"\n",
|
442 |
+
" hours = milliseconds // 3_600_000\n",
|
443 |
+
" milliseconds -= hours * 3_600_000\n",
|
444 |
+
"\n",
|
445 |
+
" minutes = milliseconds // 60_000\n",
|
446 |
+
" milliseconds -= minutes * 60_000\n",
|
447 |
+
"\n",
|
448 |
+
" seconds = milliseconds // 1_000\n",
|
449 |
+
" milliseconds -= seconds * 1_000\n",
|
450 |
+
"\n",
|
451 |
+
" hours_marker = f\"{hours:02d}:\" if always_include_hours or hours > 0 else \"\"\n",
|
452 |
+
" return (\n",
|
453 |
+
" f\"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}\"\n",
|
454 |
+
" )\n",
|
455 |
+
"\n",
|
456 |
+
"\n",
|
457 |
+
"def write_srt(transcript, file):\n",
|
458 |
+
" \"\"\"\n",
|
459 |
+
" Write a transcript to a file in SRT format.\n",
|
460 |
+
"\n",
|
461 |
+
" \"\"\"\n",
|
462 |
+
" for i, segment in enumerate(transcript, start=1):\n",
|
463 |
+
" # write srt lines\n",
|
464 |
+
" print(\n",
|
465 |
+
" f\"{i}\\n\"\n",
|
466 |
+
" f\"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> \"\n",
|
467 |
+
" f\"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\\n\"\n",
|
468 |
+
" f\"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\\n\",\n",
|
469 |
+
" file=file,\n",
|
470 |
+
" flush=True,\n",
|
471 |
+
" )\n",
|
472 |
+
"\n",
|
473 |
+
"\n",
|
474 |
+
"def find_numeral_symbol_tokens(tokenizer):\n",
|
475 |
+
" numeral_symbol_tokens = [\n",
|
476 |
+
" -1,\n",
|
477 |
+
" ]\n",
|
478 |
+
" for token, token_id in tokenizer.get_vocab().items():\n",
|
479 |
+
" has_numeral_symbol = any(c in \"0123456789%$£\" for c in token)\n",
|
480 |
+
" if has_numeral_symbol:\n",
|
481 |
+
" numeral_symbol_tokens.append(token_id)\n",
|
482 |
+
" return numeral_symbol_tokens\n",
|
483 |
+
"\n",
|
484 |
+
"\n",
|
485 |
+
"def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):\n",
|
486 |
+
" # if current word is the last word\n",
|
487 |
+
" if current_word_index == len(word_timestamps) - 1:\n",
|
488 |
+
" return word_timestamps[current_word_index][\"start\"]\n",
|
489 |
+
"\n",
|
490 |
+
" next_word_index = current_word_index + 1\n",
|
491 |
+
" while current_word_index < len(word_timestamps) - 1:\n",
|
492 |
+
" if word_timestamps[next_word_index].get(\"start\") is None:\n",
|
493 |
+
" # if next word doesn't have a start timestamp\n",
|
494 |
+
" # merge it with the current word and delete it\n",
|
495 |
+
" word_timestamps[current_word_index][\"word\"] += (\n",
|
496 |
+
" \" \" + word_timestamps[next_word_index][\"word\"]\n",
|
497 |
+
" )\n",
|
498 |
+
"\n",
|
499 |
+
" word_timestamps[next_word_index][\"word\"] = None\n",
|
500 |
+
" next_word_index += 1\n",
|
501 |
+
" if next_word_index == len(word_timestamps):\n",
|
502 |
+
" return final_timestamp\n",
|
503 |
+
"\n",
|
504 |
+
" else:\n",
|
505 |
+
" return word_timestamps[next_word_index][\"start\"]\n",
|
506 |
+
"\n",
|
507 |
+
"\n",
|
508 |
+
"def filter_missing_timestamps(\n",
|
509 |
+
" word_timestamps, initial_timestamp=0, final_timestamp=None\n",
|
510 |
+
"):\n",
|
511 |
+
" # handle the first and last word\n",
|
512 |
+
" if word_timestamps[0].get(\"start\") is None:\n",
|
513 |
+
" word_timestamps[0][\"start\"] = (\n",
|
514 |
+
" initial_timestamp if initial_timestamp is not None else 0\n",
|
515 |
+
" )\n",
|
516 |
+
" word_timestamps[0][\"end\"] = _get_next_start_timestamp(\n",
|
517 |
+
" word_timestamps, 0, final_timestamp\n",
|
518 |
+
" )\n",
|
519 |
+
"\n",
|
520 |
+
" result = [\n",
|
521 |
+
" word_timestamps[0],\n",
|
522 |
+
" ]\n",
|
523 |
+
"\n",
|
524 |
+
" for i, ws in enumerate(word_timestamps[1:], start=1):\n",
|
525 |
+
" # if ws doesn't have a start and end\n",
|
526 |
+
" # use the previous end as start and next start as end\n",
|
527 |
+
" if ws.get(\"start\") is None and ws.get(\"word\") is not None:\n",
|
528 |
+
" ws[\"start\"] = word_timestamps[i - 1][\"end\"]\n",
|
529 |
+
" ws[\"end\"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)\n",
|
530 |
+
"\n",
|
531 |
+
" if ws[\"word\"] is not None:\n",
|
532 |
+
" result.append(ws)\n",
|
533 |
+
" return result\n",
|
534 |
+
"\n",
|
535 |
+
"\n",
|
536 |
+
"def cleanup(path: str):\n",
|
537 |
+
" \"\"\"path could either be relative or absolute.\"\"\"\n",
|
538 |
+
" # check if file or directory exists\n",
|
539 |
+
" if os.path.isfile(path) or os.path.islink(path):\n",
|
540 |
+
" # remove file\n",
|
541 |
+
" os.remove(path)\n",
|
542 |
+
" elif os.path.isdir(path):\n",
|
543 |
+
" # remove directory and all its content\n",
|
544 |
+
" shutil.rmtree(path)\n",
|
545 |
+
" else:\n",
|
546 |
+
" raise ValueError(\"Path {} is not a file or dir.\".format(path))\n",
|
547 |
+
"\n",
|
548 |
+
"\n",
|
549 |
+
"def process_language_arg(language: str, model_name: str):\n",
|
550 |
+
" \"\"\"\n",
|
551 |
+
" Process the language argument to make sure it's valid and convert language names to language codes.\n",
|
552 |
+
" \"\"\"\n",
|
553 |
+
" if language is not None:\n",
|
554 |
+
" language = language.lower()\n",
|
555 |
+
" if language not in LANGUAGES:\n",
|
556 |
+
" if language in TO_LANGUAGE_CODE:\n",
|
557 |
+
" language = TO_LANGUAGE_CODE[language]\n",
|
558 |
+
" else:\n",
|
559 |
+
" raise ValueError(f\"Unsupported language: {language}\")\n",
|
560 |
+
"\n",
|
561 |
+
" if model_name.endswith(\".en\") and language != \"en\":\n",
|
562 |
+
" if language is not None:\n",
|
563 |
+
" logging.warning(\n",
|
564 |
+
" f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
|
565 |
+
" )\n",
|
566 |
+
" language = \"en\"\n",
|
567 |
+
" return language\n",
|
568 |
+
"\n",
|
569 |
+
"\n",
|
570 |
+
"def transcribe_batched(\n",
|
571 |
+
" audio_file: str,\n",
|
572 |
+
" language: str,\n",
|
573 |
+
" batch_size: int,\n",
|
574 |
+
" model_name: str,\n",
|
575 |
+
" compute_dtype: str,\n",
|
576 |
+
" suppress_numerals: bool,\n",
|
577 |
+
" device: str,\n",
|
578 |
+
"):\n",
|
579 |
+
" import whisperx\n",
|
580 |
+
"\n",
|
581 |
+
" # Faster Whisper batched\n",
|
582 |
+
" whisper_model = whisperx.load_model(\n",
|
583 |
+
" model_name,\n",
|
584 |
+
" device,\n",
|
585 |
+
" compute_type=compute_dtype,\n",
|
586 |
+
" asr_options={\"suppress_numerals\": suppress_numerals},\n",
|
587 |
+
" )\n",
|
588 |
+
" audio = whisperx.load_audio(audio_file)\n",
|
589 |
+
" result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
|
590 |
+
" del whisper_model\n",
|
591 |
+
" torch.cuda.empty_cache()\n",
|
592 |
+
" return result[\"segments\"], result[\"language\"], audio"
|
593 |
+
]
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"attachments": {},
|
597 |
+
"cell_type": "markdown",
|
598 |
+
"metadata": {
|
599 |
+
"id": "B7qWQb--1Xcw"
|
600 |
+
},
|
601 |
+
"source": [
|
602 |
+
"# Options"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"cell_type": "code",
|
607 |
+
"execution_count": null,
|
608 |
+
"metadata": {
|
609 |
+
"id": "ONlFrSnD0FOp"
|
610 |
+
},
|
611 |
+
"outputs": [],
|
612 |
+
"source": [
|
613 |
+
"# Name of the audio file\n",
|
614 |
+
"audio_path = \"20200128-Pieter Wuille (part 1 of 2) - Episode 1.mp3\"\n",
|
615 |
+
"\n",
|
616 |
+
"# Whether to enable music removal from speech, helps increase diarization quality but uses alot of ram\n",
|
617 |
+
"enable_stemming = True\n",
|
618 |
+
"\n",
|
619 |
+
"# (choose from 'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large-v3', 'large')\n",
|
620 |
+
"whisper_model_name = \"large-v2\"\n",
|
621 |
+
"\n",
|
622 |
+
"# replaces numerical digits with their pronounciation, increases diarization accuracy\n",
|
623 |
+
"suppress_numerals = True\n",
|
624 |
+
"\n",
|
625 |
+
"batch_size = 8\n",
|
626 |
+
"\n",
|
627 |
+
"language = None # autodetect language\n",
|
628 |
+
"\n",
|
629 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
|
630 |
+
]
|
631 |
+
},
|
632 |
+
{
|
633 |
+
"attachments": {},
|
634 |
+
"cell_type": "markdown",
|
635 |
+
"metadata": {
|
636 |
+
"id": "h-cY1ZEy2KVI"
|
637 |
+
},
|
638 |
+
"source": [
|
639 |
+
"# Processing"
|
640 |
+
]
|
641 |
+
},
|
642 |
+
{
|
643 |
+
"attachments": {},
|
644 |
+
"cell_type": "markdown",
|
645 |
+
"metadata": {
|
646 |
+
"id": "7ZS4xXmE2NGP"
|
647 |
+
},
|
648 |
+
"source": [
|
649 |
+
"## Separating music from speech using Demucs\n",
|
650 |
+
"\n",
|
651 |
+
"---\n",
|
652 |
+
"\n",
|
653 |
+
"By isolating the vocals from the rest of the audio, it becomes easier to identify and track individual speakers based on the spectral and temporal characteristics of their speech signals. Source separation is just one of many techniques that can be used as a preprocessing step to help improve the accuracy and reliability of the overall diarization process."
|
654 |
+
]
|
655 |
+
},
|
656 |
+
{
|
657 |
+
"cell_type": "code",
|
658 |
+
"execution_count": null,
|
659 |
+
"metadata": {
|
660 |
+
"colab": {
|
661 |
+
"base_uri": "https://localhost:8080/"
|
662 |
+
},
|
663 |
+
"id": "HKcgQUrAzsJZ",
|
664 |
+
"outputId": "dc2a1d96-20da-4749-9d64-21edacfba1b1"
|
665 |
+
},
|
666 |
+
"outputs": [],
|
667 |
+
"source": [
|
668 |
+
"if enable_stemming:\n",
|
669 |
+
" # Isolate vocals from the rest of the audio\n",
|
670 |
+
"\n",
|
671 |
+
" return_code = os.system(\n",
|
672 |
+
" f'python3 -m demucs.separate -n htdemucs --two-stems=vocals \"{audio_path}\" -o \"temp_outputs\"'\n",
|
673 |
+
" )\n",
|
674 |
+
"\n",
|
675 |
+
" if return_code != 0:\n",
|
676 |
+
" logging.warning(\"Source splitting failed, using original audio file.\")\n",
|
677 |
+
" vocal_target = audio_path\n",
|
678 |
+
" else:\n",
|
679 |
+
" vocal_target = os.path.join(\n",
|
680 |
+
" \"temp_outputs\",\n",
|
681 |
+
" \"htdemucs\",\n",
|
682 |
+
" os.path.splitext(os.path.basename(audio_path))[0],\n",
|
683 |
+
" \"vocals.wav\",\n",
|
684 |
+
" )\n",
|
685 |
+
"else:\n",
|
686 |
+
" vocal_target = audio_path"
|
687 |
+
]
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"attachments": {},
|
691 |
+
"cell_type": "markdown",
|
692 |
+
"metadata": {
|
693 |
+
"id": "UYg9VWb22Tz8"
|
694 |
+
},
|
695 |
+
"source": [
|
696 |
+
"## Transcriping audio using Whisper and realligning timestamps using Wav2Vec2\n",
|
697 |
+
"---\n",
|
698 |
+
"This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.\n",
|
699 |
+
"\n",
|
700 |
+
"The first model is called OpenAI Whisper, which is a speech recognition model that can transcribe speech with high accuracy. The code loads the whisper model and uses it to transcribe the vocal_target file.\n",
|
701 |
+
"\n",
|
702 |
+
"The output of the transcription process is a set of text segments with corresponding timestamps indicating when each segment was spoken.\n"
|
703 |
+
]
|
704 |
+
},
|
705 |
+
{
|
706 |
+
"cell_type": "code",
|
707 |
+
"execution_count": null,
|
708 |
+
"metadata": {
|
709 |
+
"id": "5-VKFn530oTl"
|
710 |
+
},
|
711 |
+
"outputs": [],
|
712 |
+
"source": [
|
713 |
+
"compute_type = \"float16\"\n",
|
714 |
+
"# or run on GPU with INT8\n",
|
715 |
+
"# compute_type = \"int8_float16\"\n",
|
716 |
+
"# or run on CPU with INT8\n",
|
717 |
+
"# compute_type = \"int8\"\n",
|
718 |
+
"\n",
|
719 |
+
"whisper_results, language, audio_waveform = transcribe_batched(\n",
|
720 |
+
" vocal_target,\n",
|
721 |
+
" language,\n",
|
722 |
+
" batch_size,\n",
|
723 |
+
" whisper_model_name,\n",
|
724 |
+
" compute_type,\n",
|
725 |
+
" suppress_numerals,\n",
|
726 |
+
" device,\n",
|
727 |
+
")"
|
728 |
+
]
|
729 |
+
},
|
730 |
+
{
|
731 |
+
"attachments": {},
|
732 |
+
"cell_type": "markdown",
|
733 |
+
"metadata": {},
|
734 |
+
"source": [
|
735 |
+
"## Aligning the transcription with the original audio using Wav2Vec2\n",
|
736 |
+
"---\n",
|
737 |
+
"The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.\n",
|
738 |
+
"\n",
|
739 |
+
"The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
|
740 |
+
"\n",
|
741 |
+
"By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.\n",
|
742 |
+
"\n",
|
743 |
+
"If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead."
|
744 |
+
]
|
745 |
+
},
|
746 |
+
{
|
747 |
+
"cell_type": "code",
|
748 |
+
"execution_count": null,
|
749 |
+
"metadata": {},
|
750 |
+
"outputs": [],
|
751 |
+
"source": [
|
752 |
+
"alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(\n",
|
753 |
+
" device,\n",
|
754 |
+
" dtype=torch.float16 if device == \"cuda\" else torch.float32,\n",
|
755 |
+
")\n",
|
756 |
+
"\n",
|
757 |
+
"audio_waveform = (\n",
|
758 |
+
" torch.from_numpy(audio_waveform)\n",
|
759 |
+
" .to(alignment_model.dtype)\n",
|
760 |
+
" .to(alignment_model.device)\n",
|
761 |
+
")\n",
|
762 |
+
"\n",
|
763 |
+
"emissions, stride = generate_emissions(\n",
|
764 |
+
" alignment_model, audio_waveform, batch_size=batch_size\n",
|
765 |
+
")\n",
|
766 |
+
"\n",
|
767 |
+
"del alignment_model\n",
|
768 |
+
"torch.cuda.empty_cache()\n",
|
769 |
+
"\n",
|
770 |
+
"full_transcript = \"\".join(segment[\"text\"] for segment in whisper_results)\n",
|
771 |
+
"\n",
|
772 |
+
"tokens_starred, text_starred = preprocess_text(\n",
|
773 |
+
" full_transcript,\n",
|
774 |
+
" romanize=True,\n",
|
775 |
+
" language=langs_to_iso[language],\n",
|
776 |
+
")\n",
|
777 |
+
"\n",
|
778 |
+
"segments, scores, blank_id = get_alignments(\n",
|
779 |
+
" emissions,\n",
|
780 |
+
" tokens_starred,\n",
|
781 |
+
" alignment_dictionary,\n",
|
782 |
+
")\n",
|
783 |
+
"\n",
|
784 |
+
"spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))\n",
|
785 |
+
"\n",
|
786 |
+
"word_timestamps = postprocess_results(text_starred, spans, stride, scores)"
|
787 |
+
]
|
788 |
+
},
|
789 |
+
{
|
790 |
+
"attachments": {},
|
791 |
+
"cell_type": "markdown",
|
792 |
+
"metadata": {
|
793 |
+
"id": "7EEaJPsQ21Rx"
|
794 |
+
},
|
795 |
+
"source": [
|
796 |
+
"## Convert audio to mono for NeMo combatibility"
|
797 |
+
]
|
798 |
+
},
|
799 |
+
{
|
800 |
+
"cell_type": "code",
|
801 |
+
"execution_count": null,
|
802 |
+
"metadata": {},
|
803 |
+
"outputs": [],
|
804 |
+
"source": [
|
805 |
+
"ROOT = os.getcwd()\n",
|
806 |
+
"temp_path = os.path.join(ROOT, \"temp_outputs\")\n",
|
807 |
+
"os.makedirs(temp_path, exist_ok=True)\n",
|
808 |
+
"torchaudio.save(\n",
|
809 |
+
" os.path.join(temp_path, \"mono_file.wav\"),\n",
|
810 |
+
" audio_waveform.cpu().unsqueeze(0).float(),\n",
|
811 |
+
" 16000,\n",
|
812 |
+
" channels_first=True,\n",
|
813 |
+
")"
|
814 |
+
]
|
815 |
+
},
|
816 |
+
{
|
817 |
+
"attachments": {},
|
818 |
+
"cell_type": "markdown",
|
819 |
+
"metadata": {
|
820 |
+
"id": "D1gkViCf2-CV"
|
821 |
+
},
|
822 |
+
"source": [
|
823 |
+
"## Speaker Diarization using NeMo MSDD Model\n",
|
824 |
+
"---\n",
|
825 |
+
"This code uses a model called Nvidia NeMo MSDD (Multi-scale Diarization Decoder) to perform speaker diarization on an audio signal. Speaker diarization is the process of separating an audio signal into different segments based on who is speaking at any given time."
|
826 |
+
]
|
827 |
+
},
|
828 |
+
{
|
829 |
+
"cell_type": "code",
|
830 |
+
"execution_count": null,
|
831 |
+
"metadata": {
|
832 |
+
"id": "C7jIpBCH02RL"
|
833 |
+
},
|
834 |
+
"outputs": [],
|
835 |
+
"source": [
|
836 |
+
"# Initialize NeMo MSDD diarization model\n",
|
837 |
+
"msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(\"cuda\")\n",
|
838 |
+
"msdd_model.diarize()\n",
|
839 |
+
"\n",
|
840 |
+
"del msdd_model\n",
|
841 |
+
"torch.cuda.empty_cache()"
|
842 |
+
]
|
843 |
+
},
|
844 |
+
{
|
845 |
+
"attachments": {},
|
846 |
+
"cell_type": "markdown",
|
847 |
+
"metadata": {
|
848 |
+
"id": "NmkZYaDAEOAg"
|
849 |
+
},
|
850 |
+
"source": [
|
851 |
+
"## Mapping Spekers to Sentences According to Timestamps"
|
852 |
+
]
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "code",
|
856 |
+
"execution_count": null,
|
857 |
+
"metadata": {
|
858 |
+
"id": "E65LUGQe02zw"
|
859 |
+
},
|
860 |
+
"outputs": [],
|
861 |
+
"source": [
|
862 |
+
"# Reading timestamps <> Speaker Labels mapping\n",
|
863 |
+
"\n",
|
864 |
+
"speaker_ts = []\n",
|
865 |
+
"with open(os.path.join(temp_path, \"pred_rttms\", \"mono_file.rttm\"), \"r\") as f:\n",
|
866 |
+
" lines = f.readlines()\n",
|
867 |
+
" for line in lines:\n",
|
868 |
+
" line_list = line.split(\" \")\n",
|
869 |
+
" s = int(float(line_list[5]) * 1000)\n",
|
870 |
+
" e = s + int(float(line_list[8]) * 1000)\n",
|
871 |
+
" speaker_ts.append([s, e, int(line_list[11].split(\"_\")[-1])])\n",
|
872 |
+
"\n",
|
873 |
+
"wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, \"start\")"
|
874 |
+
]
|
875 |
+
},
|
876 |
+
{
|
877 |
+
"attachments": {},
|
878 |
+
"cell_type": "markdown",
|
879 |
+
"metadata": {
|
880 |
+
"id": "8Ruxc8S1EXtW"
|
881 |
+
},
|
882 |
+
"source": [
|
883 |
+
"## Realligning Speech segments using Punctuation\n",
|
884 |
+
"---\n",
|
885 |
+
"\n",
|
886 |
+
"This code provides a method for disambiguating speaker labels in cases where a sentence is split between two different speakers. It uses punctuation markings to determine the dominant speaker for each sentence in the transcription.\n",
|
887 |
+
"\n",
|
888 |
+
"```\n",
|
889 |
+
"Speaker A: It's got to come from somewhere else. Yeah, that one's also fun because you know the lows are\n",
|
890 |
+
"Speaker B: going to suck, right? So it's actually it hits you on both sides.\n",
|
891 |
+
"```\n",
|
892 |
+
"\n",
|
893 |
+
"For example, if a sentence is split between two speakers, the code takes the mode of speaker labels for each word in the sentence, and uses that speaker label for the whole sentence. This can help to improve the accuracy of speaker diarization, especially in cases where the Whisper model may not take fine utterances like \"hmm\" and \"yeah\" into account, but the Diarization Model (Nemo) may include them, leading to inconsistent results.\n",
|
894 |
+
"\n",
|
895 |
+
"The code also handles cases where one speaker is giving a monologue while other speakers are making occasional comments in the background. It ignores the comments and assigns the entire monologue to the speaker who is speaking the majority of the time. This provides a robust and reliable method for realigning speech segments to their respective speakers based on punctuation in the transcription."
|
896 |
+
]
|
897 |
+
},
|
898 |
+
{
|
899 |
+
"cell_type": "code",
|
900 |
+
"execution_count": null,
|
901 |
+
"metadata": {
|
902 |
+
"id": "pgfC5hA41BXu"
|
903 |
+
},
|
904 |
+
"outputs": [],
|
905 |
+
"source": [
|
906 |
+
"if language in punct_model_langs:\n",
|
907 |
+
" # restoring punctuation in the transcript to help realign the sentences\n",
|
908 |
+
" punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
|
909 |
+
"\n",
|
910 |
+
" words_list = list(map(lambda x: x[\"word\"], wsm))\n",
|
911 |
+
"\n",
|
912 |
+
" labled_words = punct_model.predict(words_list,chunk_size=230)\n",
|
913 |
+
"\n",
|
914 |
+
" ending_puncts = \".?!\"\n",
|
915 |
+
" model_puncts = \".,;:!?\"\n",
|
916 |
+
"\n",
|
917 |
+
" # We don't want to punctuate U.S.A. with a period. Right?\n",
|
918 |
+
" is_acronym = lambda x: re.fullmatch(r\"\\b(?:[a-zA-Z]\\.){2,}\", x)\n",
|
919 |
+
"\n",
|
920 |
+
" for word_dict, labeled_tuple in zip(wsm, labled_words):\n",
|
921 |
+
" word = word_dict[\"word\"]\n",
|
922 |
+
" if (\n",
|
923 |
+
" word\n",
|
924 |
+
" and labeled_tuple[1] in ending_puncts\n",
|
925 |
+
" and (word[-1] not in model_puncts or is_acronym(word))\n",
|
926 |
+
" ):\n",
|
927 |
+
" word += labeled_tuple[1]\n",
|
928 |
+
" if word.endswith(\"..\"):\n",
|
929 |
+
" word = word.rstrip(\".\")\n",
|
930 |
+
" word_dict[\"word\"] = word\n",
|
931 |
+
"\n",
|
932 |
+
"else:\n",
|
933 |
+
" logging.warning(\n",
|
934 |
+
" f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
|
935 |
+
" )\n",
|
936 |
+
"\n",
|
937 |
+
"wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
|
938 |
+
"ssm = get_sentences_speaker_mapping(wsm, speaker_ts)"
|
939 |
+
]
|
940 |
+
},
|
941 |
+
{
|
942 |
+
"attachments": {},
|
943 |
+
"cell_type": "markdown",
|
944 |
+
"metadata": {
|
945 |
+
"id": "vF2QAtLOFvwZ"
|
946 |
+
},
|
947 |
+
"source": [
|
948 |
+
"## Cleanup and Exporing the results"
|
949 |
+
]
|
950 |
+
},
|
951 |
+
{
|
952 |
+
"cell_type": "code",
|
953 |
+
"execution_count": null,
|
954 |
+
"metadata": {
|
955 |
+
"id": "kFTyKI6B1MI0"
|
956 |
+
},
|
957 |
+
"outputs": [],
|
958 |
+
"source": [
|
959 |
+
"with open(f\"{os.path.splitext(audio_path)[0]}.txt\", \"w\", encoding=\"utf-8-sig\") as f:\n",
|
960 |
+
" get_speaker_aware_transcript(ssm, f)\n",
|
961 |
+
"\n",
|
962 |
+
"with open(f\"{os.path.splitext(audio_path)[0]}.srt\", \"w\", encoding=\"utf-8-sig\") as srt:\n",
|
963 |
+
" write_srt(ssm, srt)\n",
|
964 |
+
"\n",
|
965 |
+
"cleanup(temp_path)"
|
966 |
+
]
|
967 |
+
}
|
968 |
+
],
|
969 |
+
"metadata": {
|
970 |
+
"accelerator": "GPU",
|
971 |
+
"colab": {
|
972 |
+
"authorship_tag": "ABX9TyOyiQNkD+ROzss634BOsrSh",
|
973 |
+
"collapsed_sections": [
|
974 |
+
"eCmjcOc9yEtQ",
|
975 |
+
"jbsUt3SwyhjD"
|
976 |
+
],
|
977 |
+
"include_colab_link": true,
|
978 |
+
"provenance": []
|
979 |
+
},
|
980 |
+
"gpuClass": "standard",
|
981 |
+
"kernelspec": {
|
982 |
+
"display_name": "Python 3",
|
983 |
+
"name": "python3"
|
984 |
+
},
|
985 |
+
"language_info": {
|
986 |
+
"codemirror_mode": {
|
987 |
+
"name": "ipython",
|
988 |
+
"version": 3
|
989 |
+
},
|
990 |
+
"file_extension": ".py",
|
991 |
+
"mimetype": "text/x-python",
|
992 |
+
"name": "python",
|
993 |
+
"nbconvert_exporter": "python",
|
994 |
+
"pygments_lexer": "ipython3",
|
995 |
+
"version": "3.10.12"
|
996 |
+
}
|
997 |
+
},
|
998 |
+
"nbformat": 4,
|
999 |
+
"nbformat_minor": 0
|
1000 |
+
}
|
whisper_diarization_main/diarize.py
ADDED
@@ -0,0 +1,235 @@
import argparse
import logging
import os
import re

import torch
import torchaudio
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from deepmultilingualpunctuation import PunctuationModel
from nemo.collections.asr.models.msdd_models import NeuralDiarizer

from helpers import (
    cleanup,
    create_config,
    get_realigned_ws_mapping_with_punctuation,
    get_sentences_speaker_mapping,
    get_speaker_aware_transcript,
    get_words_speaker_mapping,
    langs_to_iso,
    punct_model_langs,
    whisper_langs,
    write_srt,
)
from transcription_helpers import transcribe_batched

mtypes = {"cpu": "int8", "cuda": "float16"}

# Initialize parser
parser = argparse.ArgumentParser()
parser.add_argument(
    "-a", "--audio", help="name of the target audio file", required=True
)
parser.add_argument(
    "--no-stem",
    action="store_false",
    dest="stemming",
    default=True,
    help="Disables source separation."
    "This helps with long files that don't contain a lot of music.",
)

parser.add_argument(
    "--suppress_numerals",
    action="store_true",
    dest="suppress_numerals",
    default=False,
    help="Suppresses Numerical Digits."
    "This helps the diarization accuracy but converts all digits into written text.",
)

parser.add_argument(
    "--whisper-model",
    dest="model_name",
    default="medium.en",
    help="name of the Whisper model to use",
)

parser.add_argument(
    "--batch-size",
    type=int,
    dest="batch_size",
    default=8,
    help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
)

parser.add_argument(
    "--language",
    type=str,
    default=None,
    choices=whisper_langs,
    help="Language spoken in the audio, specify None to perform language detection",
)

parser.add_argument(
    "--device",
    dest="device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="if you have a GPU use 'cuda', otherwise 'cpu'",
)

args = parser.parse_args()

if args.stemming:
    # Isolate vocals from the rest of the audio

    return_code = os.system(
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
    )

    if return_code != 0:
        logging.warning(
            "Source splitting failed, using original audio file. Use --no-stem argument to disable it."
        )
        vocal_target = args.audio
    else:
        vocal_target = os.path.join(
            "temp_outputs",
            "htdemucs",
            os.path.splitext(os.path.basename(args.audio))[0],
            "vocals.wav",
        )
else:
    vocal_target = args.audio


# Transcribe the audio file

whisper_results, language, audio_waveform = transcribe_batched(
    vocal_target,
    args.language,
    args.batch_size,
    args.model_name,
    mtypes[args.device],
    args.suppress_numerals,
    args.device,
)

# Forced Alignment
alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
    args.device,
    dtype=torch.float16 if args.device == "cuda" else torch.float32,
)

audio_waveform = (
    torch.from_numpy(audio_waveform)
    .to(alignment_model.dtype)
    .to(alignment_model.device)
)
emissions, stride = generate_emissions(
    alignment_model, audio_waveform, batch_size=args.batch_size
)

del alignment_model
torch.cuda.empty_cache()

full_transcript = "".join(segment["text"] for segment in whisper_results)

tokens_starred, text_starred = preprocess_text(
    full_transcript,
    romanize=True,
    language=langs_to_iso[language],
)

segments, scores, blank_id = get_alignments(
    emissions,
    tokens_starred,
    alignment_dictionary,
)

spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))

word_timestamps = postprocess_results(text_starred, spans, stride, scores)


# convert audio to mono for NeMo compatibility
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
torchaudio.save(
    os.path.join(temp_path, "mono_file.wav"),
    audio_waveform.cpu().unsqueeze(0).float(),
    16000,
    channels_first=True,
)


# Initialize NeMo MSDD diarization model
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
msdd_model.diarize()

del msdd_model
torch.cuda.empty_cache()

# Reading timestamps <> Speaker Labels mapping


speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    lines = f.readlines()
    for line in lines:
        line_list = line.split(" ")
        s = int(float(line_list[5]) * 1000)
        e = s + int(float(line_list[8]) * 1000)
        speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

if language in punct_model_langs:
    # restoring punctuation in the transcript to help realign the sentences
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = list(map(lambda x: x["word"], wsm))

    labled_words = punct_model.predict(words_list, chunk_size=230)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # We don't want to punctuate U.S.A. with a period. Right?
    is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled_tuple in zip(wsm, labled_words):
        word = word_dict["word"]
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            if word.endswith(".."):
                word = word.rstrip(".")
            word_dict["word"] = word

else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language. Using the original punctuation."
|
224 |
+
)
|
225 |
+
|
226 |
+
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
|
227 |
+
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
|
228 |
+
|
229 |
+
with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
|
230 |
+
get_speaker_aware_transcript(ssm, f)
|
231 |
+
|
232 |
+
with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
|
233 |
+
write_srt(ssm, srt)
|
234 |
+
|
235 |
+
cleanup(temp_path)
|
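Note (illustrative, not one of the uploaded files): the RTTM-parsing loop near the end of diarize.py reduces each predicted speaker segment to a [start_ms, end_ms, speaker_id] triple. A minimal sketch of that one step, assuming the spacing NeMo uses when writing pred_rttms/mono_file.rttm (indices 5, 8 and 11 rely on split(" ") keeping the empty strings produced by the padded columns):

def rttm_line_to_turn(line: str) -> list:
    # Mirrors the loop above: field 5 is the segment onset (s),
    # field 8 its duration (s), field 11 the speaker label.
    fields = line.split(" ")
    start_ms = int(float(fields[5]) * 1000)           # onset in ms
    end_ms = start_ms + int(float(fields[8]) * 1000)  # onset + duration in ms
    speaker_id = int(fields[11].split("_")[-1])       # "speaker_3" -> 3
    return [start_ms, end_ms, speaker_id]

These triples are what get_words_speaker_mapping and get_sentences_speaker_mapping in helpers.py consume downstream.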
whisper_diarization_main/diarize_parallel.py
ADDED
@@ -0,0 +1,217 @@
import argparse
import logging
import os
import re
import subprocess

import torch
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from deepmultilingualpunctuation import PunctuationModel

from helpers import (
    cleanup,
    get_realigned_ws_mapping_with_punctuation,
    get_sentences_speaker_mapping,
    get_speaker_aware_transcript,
    get_words_speaker_mapping,
    langs_to_iso,
    punct_model_langs,
    whisper_langs,
    write_srt,
)
from transcription_helpers import transcribe_batched

mtypes = {"cpu": "int8", "cuda": "float16"}

# Initialize parser
parser = argparse.ArgumentParser()
parser.add_argument(
    "-a", "--audio", help="name of the target audio file", required=True
)
parser.add_argument(
    "--no-stem",
    action="store_false",
    dest="stemming",
    default=True,
    help="Disables source separation. "
    "This helps with long files that don't contain a lot of music.",
)

parser.add_argument(
    "--suppress_numerals",
    action="store_true",
    dest="suppress_numerals",
    default=False,
    help="Suppresses numerical digits. "
    "This helps the diarization accuracy but converts all digits into written text.",
)

parser.add_argument(
    "--whisper-model",
    dest="model_name",
    default="medium.en",
    help="name of the Whisper model to use",
)

parser.add_argument(
    "--batch-size",
    type=int,
    dest="batch_size",
    default=8,
    help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
)

parser.add_argument(
    "--language",
    type=str,
    default=None,
    choices=whisper_langs,
    help="Language spoken in the audio, specify None to perform language detection",
)

parser.add_argument(
    "--device",
    dest="device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="if you have a GPU use 'cuda', otherwise 'cpu'",
)

args = parser.parse_args()

if args.stemming:
    # Isolate vocals from the rest of the audio
    return_code = os.system(
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
    )

    if return_code != 0:
        logging.warning(
            "Source splitting failed, using original audio file. Use --no-stem argument to disable it."
        )
        vocal_target = args.audio
    else:
        vocal_target = os.path.join(
            "temp_outputs",
            "htdemucs",
            os.path.splitext(os.path.basename(args.audio))[0],
            "vocals.wav",
        )
else:
    vocal_target = args.audio

logging.info("Starting NeMo process with vocal_target: %s", vocal_target)
nemo_process = subprocess.Popen(
    ["python3", "nemo_process.py", "-a", vocal_target, "--device", args.device],
)

# Transcribe the audio file
whisper_results, language, audio_waveform = transcribe_batched(
    vocal_target,
    args.language,
    args.batch_size,
    args.model_name,
    mtypes[args.device],
    args.suppress_numerals,
    args.device,
)

# Forced alignment
alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
    args.device,
    dtype=torch.float16 if args.device == "cuda" else torch.float32,
)

audio_waveform = (
    torch.from_numpy(audio_waveform)
    .to(alignment_model.dtype)
    .to(alignment_model.device)
)
emissions, stride = generate_emissions(
    alignment_model, audio_waveform, batch_size=args.batch_size
)

del alignment_model
torch.cuda.empty_cache()

full_transcript = "".join(segment["text"] for segment in whisper_results)

tokens_starred, text_starred = preprocess_text(
    full_transcript,
    romanize=True,
    language=langs_to_iso[language],
)

segments, scores, blank_id = get_alignments(
    emissions,
    tokens_starred,
    alignment_dictionary,
)

spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))

word_timestamps = postprocess_results(text_starred, spans, stride, scores)

# Reading timestamps <> Speaker Labels mapping
nemo_process.communicate()
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")

speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    lines = f.readlines()
    for line in lines:
        line_list = line.split(" ")
        s = int(float(line_list[5]) * 1000)
        e = s + int(float(line_list[8]) * 1000)
        speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

if language in punct_model_langs:
    # restoring punctuation in the transcript to help realign the sentences
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = list(map(lambda x: x["word"], wsm))

    labled_words = punct_model.predict(words_list, chunk_size=230)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # We don't want to punctuate U.S.A. with a period. Right?
    is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled_tuple in zip(wsm, labled_words):
        word = word_dict["word"]
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            if word.endswith(".."):
                word = word.rstrip(".")
            word_dict["word"] = word

else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language. Using the original punctuation."
    )

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
    get_speaker_aware_transcript(ssm, f)

with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
    write_srt(ssm, srt)

cleanup(temp_path)
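Note (illustrative, not one of the uploaded files): diarize_parallel.py overlaps NeMo diarization with Whisper transcription by launching nemo_process.py as a child process and only waiting on it once the word timestamps are ready. Reduced to a minimal sketch, with a stand-in comment where the parent's transcription work would go:

import subprocess

child = subprocess.Popen(
    ["python3", "nemo_process.py", "-a", "vocals.wav", "--device", "cpu"]
)
# ... transcription + forced alignment run here, in the parent process ...
child.communicate()  # block until the child has written its RTTM output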
whisper_diarization_main/helpers.py
ADDED
@@ -0,0 +1,580 @@
import json
import logging
import os
import shutil

import nltk
import wget
from omegaconf import OmegaConf
from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE

punct_model_langs = [
    "en",
    "fr",
    "de",
    "es",
    "it",
    "nl",
    "pt",
    "bg",
    "pl",
    "cs",
    "sk",
    "sl",
]
wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(
    DEFAULT_ALIGN_MODELS_HF.keys()
)

whisper_langs = sorted(LANGUAGES.keys()) + sorted(
    [k.title() for k in TO_LANGUAGE_CODE.keys()]
)

langs_to_iso = {
    "aa": "aar",
    "ab": "abk",
    "ae": "ave",
    "af": "afr",
    "ak": "aka",
    "am": "amh",
    "an": "arg",
    "ar": "ara",
    "as": "asm",
    "av": "ava",
    "ay": "aym",
    "az": "aze",
    "ba": "bak",
    "be": "bel",
    "bg": "bul",
    "bh": "bih",
    "bi": "bis",
    "bm": "bam",
    "bn": "ben",
    "bo": "tib",
    "br": "bre",
    "bs": "bos",
    "ca": "cat",
    "ce": "che",
    "ch": "cha",
    "co": "cos",
    "cr": "cre",
    "cs": "cze",
    "cu": "chu",
    "cv": "chv",
    "cy": "wel",
    "da": "dan",
    "de": "ger",
    "dv": "div",
    "dz": "dzo",
    "ee": "ewe",
    "el": "gre",
    "en": "eng",
    "eo": "epo",
    "es": "spa",
    "et": "est",
    "eu": "baq",
    "fa": "per",
    "ff": "ful",
    "fi": "fin",
    "fj": "fij",
    "fo": "fao",
    "fr": "fre",
    "fy": "fry",
    "ga": "gle",
    "gd": "gla",
    "gl": "glg",
    "gn": "grn",
    "gu": "guj",
    "gv": "glv",
    "ha": "hau",
    "he": "heb",
    "hi": "hin",
    "ho": "hmo",
    "hr": "hrv",
    "ht": "hat",
    "hu": "hun",
    "hy": "arm",
    "hz": "her",
    "ia": "ina",
    "id": "ind",
    "ie": "ile",
    "ig": "ibo",
    "ii": "iii",
    "ik": "ipk",
    "io": "ido",
    "is": "ice",
    "it": "ita",
    "iu": "iku",
    "ja": "jpn",
    "jv": "jav",
    "ka": "geo",
    "kg": "kon",
    "ki": "kik",
    "kj": "kua",
    "kk": "kaz",
    "kl": "kal",
    "km": "khm",
    "kn": "kan",
    "ko": "kor",
    "kr": "kau",
    "ks": "kas",
    "ku": "kur",
    "kv": "kom",
    "kw": "cor",
    "ky": "kir",
    "la": "lat",
    "lb": "ltz",
    "lg": "lug",
    "li": "lim",
    "ln": "lin",
    "lo": "lao",
    "lt": "lit",
    "lu": "lub",
    "lv": "lav",
    "mg": "mlg",
    "mh": "mah",
    "mi": "mao",
    "mk": "mac",
    "ml": "mal",
    "mn": "mon",
    "mr": "mar",
    "ms": "may",
    "mt": "mlt",
    "my": "bur",
    "na": "nau",
    "nb": "nob",
    "nd": "nde",
    "ne": "nep",
    "ng": "ndo",
    "nl": "dut",
    "nn": "nno",
    "no": "nor",
    "nr": "nbl",
    "nv": "nav",
    "ny": "nya",
    "oc": "oci",
    "oj": "oji",
    "om": "orm",
    "or": "ori",
    "os": "oss",
    "pa": "pan",
    "pi": "pli",
    "pl": "pol",
    "ps": "pus",
    "pt": "por",
    "qu": "que",
    "rm": "roh",
    "rn": "run",
    "ro": "rum",
    "ru": "rus",
    "rw": "kin",
    "sa": "san",
    "sc": "srd",
    "sd": "snd",
    "se": "sme",
    "sg": "sag",
    "si": "sin",
    "sk": "slo",
    "sl": "slv",
    "sm": "smo",
    "sn": "sna",
    "so": "som",
    "sq": "alb",
    "sr": "srp",
    "ss": "ssw",
    "st": "sot",
    "su": "sun",
    "sv": "swe",
    "sw": "swa",
    "ta": "tam",
    "te": "tel",
    "tg": "tgk",
    "th": "tha",
    "ti": "tir",
    "tk": "tuk",
    "tl": "tgl",
    "tn": "tsn",
    "to": "ton",
    "tr": "tur",
    "ts": "tso",
    "tt": "tat",
    "tw": "twi",
    "ty": "tah",
    "ug": "uig",
    "uk": "ukr",
    "ur": "urd",
    "uz": "uzb",
    "ve": "ven",
    "vi": "vie",
    "vo": "vol",
    "wa": "wln",
    "wo": "wol",
    "xh": "xho",
    "yi": "yid",
    "yo": "yor",
    "za": "zha",
    "zh": "chi",
    "zu": "zul",
}


def create_config(output_dir):
    DOMAIN_TYPE = "telephonic"  # Can be meeting, telephonic, or general based on domain type of the audio file
    CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG_PATH):
        os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
        CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
        MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)

    config = OmegaConf.load(MODEL_CONFIG_PATH)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"
    config.num_workers = 0
    config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
    config.diarizer.out_dir = (
        output_dir  # Directory to store intermediate files and prediction outputs
    )

    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = (
        False  # compute VAD provided with model_path to vad config
    )
    config.diarizer.clustering.parameters.oracle_num_speakers = False

    # Here, we use our in-house pretrained NeMo VAD model
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    config.diarizer.msdd_model.model_path = (
        "diar_msdd_telephonic"  # Telephonic speaker diarization model
    )

    return config


def get_word_ts_anchor(s, e, option="start"):
    if option == "end":
        return e
    elif option == "mid":
        return (s + e) / 2
    return s


def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    s, e, sp = spk_ts[0]
    wrd_pos, turn_idx = 0, 0
    wrd_spk_mapping = []
    for wrd_dict in wrd_ts:
        ws, we, wrd = (
            int(wrd_dict["start"] * 1000),
            int(wrd_dict["end"] * 1000),
            wrd_dict["text"],
        )
        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
        while wrd_pos > float(e):
            turn_idx += 1
            turn_idx = min(turn_idx, len(spk_ts) - 1)
            s, e, sp = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                e = get_word_ts_anchor(ws, we, option="end")
        wrd_spk_mapping.append(
            {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
        )
    return wrd_spk_mapping


sentence_ending_punctuations = ".?!"


def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1


def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )


def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    snts.append(snt)
    return snts


def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    previous_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{previous_speaker}: ")

    for sentence_dict in sentences_speaker_mapping:
        speaker = sentence_dict["speaker"]
        sentence = sentence_dict["text"]

        # If this speaker doesn't match the previous one, start a new paragraph
        if speaker != previous_speaker:
            f.write(f"\n\n{speaker}: ")
            previous_speaker = speaker

        # No matter what, write the current sentence
        f.write(sentence + " ")


def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.
    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = [
        -1,
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens


def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
    # if current word is the last word
    if current_word_index == len(word_timestamps) - 1:
        return word_timestamps[current_word_index]["start"]

    next_word_index = current_word_index + 1
    while current_word_index < len(word_timestamps) - 1:
        if word_timestamps[next_word_index].get("start") is None:
            # if next word doesn't have a start timestamp
            # merge it with the current word and delete it
            word_timestamps[current_word_index]["word"] += (
                " " + word_timestamps[next_word_index]["word"]
            )

            word_timestamps[next_word_index]["word"] = None
            next_word_index += 1
            if next_word_index == len(word_timestamps):
                return final_timestamp

        else:
            return word_timestamps[next_word_index]["start"]


def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    # handle the first and last word
    if word_timestamps[0].get("start") is None:
        word_timestamps[0]["start"] = (
            initial_timestamp if initial_timestamp is not None else 0
        )
        word_timestamps[0]["end"] = _get_next_start_timestamp(
            word_timestamps, 0, final_timestamp
        )

    result = [
        word_timestamps[0],
    ]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if ws doesn't have a start and end
        # use the previous end as start and next start as end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)

        if ws["word"] is not None:
            result.append(ws)
    return result


def cleanup(path: str):
    """path could either be relative or absolute."""
    # check if file or directory exists
    if os.path.isfile(path) or os.path.islink(path):
        # remove file
        os.remove(path)
    elif os.path.isdir(path):
        # remove directory and all its content
        shutil.rmtree(path)
    else:
        raise ValueError("Path {} is not a file or dir.".format(path))


def process_language_arg(language: str, model_name: str):
    """
    Process the language argument to make sure it's valid and convert language names to language codes.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if model_name.endswith(".en") and language != "en":
        if language is not None:
            logging.warning(
                f"{model_name} is an English-only model but received '{language}'; using English instead."
            )
        language = "en"
    return language
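Note (illustrative, not part of helpers.py): a quick sanity check of how format_timestamp renders millisecond values into the "HH:MM:SS,mmm" stamps that write_srt uses, assuming helpers.py is importable from the working directory:

from helpers import format_timestamp

print(format_timestamp(3_601_001, always_include_hours=True, decimal_marker=","))
# 01:00:01,001
print(format_timestamp(59_500))
# 00:59.500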
whisper_diarization_main/nemo_msdd_configs/diar_infer_general.yaml
ADDED
@@ -0,0 +1,95 @@
# This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
# The configurations in this YAML file is optimized to show balanced performances on various types of domain. VAD is optimized on multilingual ASR datasets and diarizer is optimized on DIHARD3 development set.
# An example line in an input manifest file (`.json` format):
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
name: &name "ClusterDiarizer"

num_workers: 1
sample_rate: 16000
batch_size: 64
device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
  out_dir: ???
  oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
  collar: 0.25 # Collar value for scoring
  ignore_overlap: True # Consider or ignore overlap segments while scoring

  vad:
    model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
    external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set

    parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets
      window_length_in_sec: 0.63 # Window length in sec for VAD context input
      shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction
      smoothing: False # False or type of smoothing method (eg: median)
      overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
      onset: 0.5 # Onset threshold for detecting the beginning and end of a speech
      offset: 0.3 # Offset threshold for detecting the end of a speech
      pad_onset: 0.2 # Adding durations before each speech segment
      pad_offset: 0.2 # Adding durations after each speech segment
      min_duration_on: 0.5 # Threshold for small non_speech deletion
      min_duration_off: 0.5 # Threshold for short speech segment deletion
      filter_speech_first: True

  speaker_embeddings:
    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
    parameters:
      window_length_in_sec: [1.9,1.2,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
      shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
      multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.

  clustering:
    parameters:
      oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
      max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
      max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
      sparse_search_volume: 10 # The higher the number, the more values will be examined with more time.
      maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
      embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)


  msdd_model:
    model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
    parameters:
      use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
      infer_batch_size: 25 # Batch size for MSDD inference.
      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
      seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
      split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
      diar_window_length: 50 # The length of split short sequence when split_infer is True.
      overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.

  asr:
    model_path: null # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
    parameters:
      asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
      asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
      asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
      decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
      word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
      word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
      fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
      colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
      print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
      break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)

    ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
      pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
      beam_width: 32
      alpha: 0.5
      beta: 2.5

    realigning_lm_parameters: # Experimental feature
      arpa_language_model: null # Provide a KenLM language model in .arpa format.
      min_number_of_words: 3 # Min number of words for the left context.
      max_number_of_words: 10 # Max number of words for the right context.
      logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
whisper_diarization_main/nemo_msdd_configs/diar_infer_meeting.yaml
ADDED
@@ -0,0 +1,94 @@
# This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
# The configurations in this YAML file is suitable for 3~5 speakers participating in a meeting and may not show the best performance on other types of dialogues.
# An example line in an input manifest file (`.json` format):
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
name: &name "ClusterDiarizer"

num_workers: 1
sample_rate: 16000
batch_size: 64
device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
  out_dir: ???
  oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
  collar: 0.25 # Collar value for scoring
  ignore_overlap: True # Consider or ignore overlap segments while scoring

  vad:
    model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
    external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set

    parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
      window_length_in_sec: 0.63 # Window length in sec for VAD context input
      shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction
      smoothing: False # False or type of smoothing method (eg: median)
      overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
      onset: 0.9 # Onset threshold for detecting the beginning and end of a speech
      offset: 0.5 # Offset threshold for detecting the end of a speech
      pad_onset: 0 # Adding durations before each speech segment
      pad_offset: 0 # Adding durations after each speech segment
      min_duration_on: 0 # Threshold for small non_speech deletion
      min_duration_off: 0.6 # Threshold for short speech segment deletion
      filter_speech_first: True

  speaker_embeddings:
    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
    parameters:
      window_length_in_sec: [3.0,2.5,2.0,1.5,1.0,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
      shift_length_in_sec: [1.5,1.25,1.0,0.75,0.5,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
      multiscale_weights: [1,1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.

  clustering:
    parameters:
      oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
      max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
      max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
      sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
      maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
      embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)

  msdd_model:
    model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
    parameters:
      use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
      infer_batch_size: 25 # Batch size for MSDD inference.
      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
      seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
      split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
      diar_window_length: 50 # The length of split short sequence when split_infer is True.
      overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.

  asr:
    model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
    parameters:
      asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
      asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
      asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
      decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
      word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
      word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
      fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
      colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
      print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
      break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)

    ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
      pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
      beam_width: 32
      alpha: 0.5
      beta: 2.5

    realigning_lm_parameters: # Experimental feature
      arpa_language_model: null # Provide a KenLM language model in .arpa format.
      min_number_of_words: 3 # Min number of words for the left context.
      max_number_of_words: 10 # Max number of words for the right context.
      logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
whisper_diarization_main/nemo_msdd_configs/diar_infer_telephonic.yaml
ADDED
@@ -0,0 +1,94 @@
1 |
+
# This YAML file is created for all types of offline speaker diarization inference tasks in `<NeMo git root>/example/speaker_tasks/diarization` folder.
|
2 |
+
# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
|
3 |
+
# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
|
4 |
+
# The configurations in this YAML file is suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on the other types of acoustic conditions or dialogues.
|
5 |
+
# An example line in an input manifest file (`.json` format):
|
6 |
+
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
|
7 |
+
name: &name "ClusterDiarizer"
|
8 |
+
|
9 |
+
num_workers: 1
|
10 |
+
sample_rate: 16000
|
11 |
+
batch_size: 64
|
12 |
+
device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
|
13 |
+
verbose: True # enable additional logging
|
14 |
+
|
15 |
+
diarizer:
|
16 |
+
manifest_filepath: ???
|
17 |
+
out_dir: ???
|
18 |
+
oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
|
19 |
+
collar: 0.25 # Collar value for scoring
|
20 |
+
ignore_overlap: True # Consider or ignore overlap segments while scoring
|
21 |
+
|
22 |
+
vad:
|
23 |
+
model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
|
24 |
+
external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
|
25 |
+
|
26 |
+
parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
|
27 |
+
window_length_in_sec: 0.15 # Window length in sec for VAD context input
|
28 |
+
shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction
|
29 |
+
smoothing: "median" # False or type of smoothing method (eg: median)
|
30 |
+
overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
|
31 |
+
onset: 0.1 # Onset threshold for detecting the beginning and end of a speech
|
32 |
+
      offset: 0.1 # Offset threshold for detecting the end of a speech
      pad_onset: 0.1 # Adding durations before each speech segment
      pad_offset: 0 # Adding durations after each speech segment
      min_duration_on: 0 # Threshold for small non_speech deletion
      min_duration_off: 0.2 # Threshold for short speech segment deletion
      filter_speech_first: True

  speaker_embeddings:
    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
    parameters:
      window_length_in_sec: [1.5,1.25,1.0,0.75,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
      shift_length_in_sec: [0.75,0.625,0.5,0.375,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
      multiscale_weights: [1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.

  clustering:
    parameters:
      oracle_num_speakers: False # If True, use num of speakers value provided in manifest file.
      max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
      max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
      sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
      maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
      embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)

  msdd_model:
    model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
    parameters:
      use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
      infer_batch_size: 25 # Batch size for MSDD inference.
      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
      seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
      split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
      diar_window_length: 50 # The length of split short sequence when split_infer is True.
      overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.

  asr:
    model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes.
    parameters:
      asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference.
      asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD.
      asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null.
      decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model.
      word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2].
      word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'.
      fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature.
      colored_text: False # If True, use colored text to distinguish speakers in the output transcript.
      print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript.
      break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars)

    ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode)
      pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file.
      beam_width: 32
      alpha: 0.5
      beta: 2.5

    realigning_lm_parameters: # Experimental feature
      arpa_language_model: null # Provide a KenLM language model in .arpa format.
      min_number_of_words: 3 # Min number of words for the left context.
      max_number_of_words: 10 # Max number of words for the right context.
      logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses.
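Before NeMo can run with a config like the one above, the diarizer still needs an input manifest and an output directory. The following is a minimal sketch, assuming the standard diarizer.manifest_filepath and diarizer.out_dir keys and illustrative paths and manifest values; the repository's own create_config helper in helpers.py may set this up differently.

# Sketch only: load a NeMo MSDD config with OmegaConf and point it at one audio file.
# Paths, manifest fields, and the "cuda" device below are assumptions for illustration.
import json
import os

from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from omegaconf import OmegaConf

cfg = OmegaConf.load("whisper_diarization_main/nemo_msdd_configs/diar_infer_telephonic.yaml")

# NeMo's diarizer reads its input from a JSON-lines manifest, one record per recording.
os.makedirs("temp_outputs", exist_ok=True)
manifest_entry = {
    "audio_filepath": "temp_outputs/mono_file.wav",
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "rttm_filepath": None,
    "uem_filepath": None,
}
with open("temp_outputs/manifest.json", "w") as f:
    f.write(json.dumps(manifest_entry) + "\n")

cfg.diarizer.manifest_filepath = "temp_outputs/manifest.json"
cfg.diarizer.out_dir = "temp_outputs"  # predicted RTTM files are written under this directory

NeuralDiarizer(cfg=cfg).to("cuda").diarize()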
whisper_diarization_main/nemo_process.py
ADDED
@@ -0,0 +1,31 @@
import argparse
import os

import torch
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from pydub import AudioSegment

from helpers import create_config

parser = argparse.ArgumentParser()
parser.add_argument(
    "-a", "--audio", help="name of the target audio file", required=True
)
parser.add_argument(
    "--device",
    dest="device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="if you have a GPU use 'cuda', otherwise 'cpu'",
)
args = parser.parse_args()

# convert audio to mono for NeMo compatibility
sound = AudioSegment.from_file(args.audio).set_channels(1)
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
sound.export(os.path.join(temp_path, "mono_file.wav"), format="wav")

# Initialize NeMo MSDD diarization model
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
msdd_model.diarize()
whisper_diarization_main/requirements.txt
ADDED
@@ -0,0 +1,7 @@
wget
nemo_toolkit[asr]==1.23.0
git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560
git+https://github.com/adefossez/demucs.git
git+https://github.com/oliverguhr/deepmultilingualpunctuation.git
git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
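These pinned dependencies install with a plain pip invocation. A small sketch, assuming you run it from the directory that contains whisper_diarization_main:

# Sketch only: install the pinned requirements programmatically; equivalent to running
# `pip install -r whisper_diarization_main/requirements.txt` in a shell.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-r", "whisper_diarization_main/requirements.txt"]
)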
whisper_diarization_main/transcription_helpers.py
ADDED
@@ -0,0 +1,74 @@
import torch


def transcribe(
    audio_file: str,
    language: str,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    from faster_whisper import WhisperModel

    from helpers import find_numeral_symbol_tokens, wav2vec2_langs

    # Faster Whisper non-batched
    # Run on GPU with FP16
    whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype)

    # or run on GPU with INT8
    # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
    # or run on CPU with INT8
    # model = WhisperModel(model_size, device="cpu", compute_type="int8")

    if suppress_numerals:
        numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
    else:
        numeral_symbol_tokens = None

    if language is not None and language in wav2vec2_langs:
        word_timestamps = False
    else:
        word_timestamps = True

    segments, info = whisper_model.transcribe(
        audio_file,
        language=language,
        beam_size=5,
        word_timestamps=word_timestamps,  # TODO: disable this if the language is supported by wav2vec2
        suppress_tokens=numeral_symbol_tokens,
        vad_filter=True,
    )
    whisper_results = []
    for segment in segments:
        whisper_results.append(segment._asdict())
    # clear gpu vram
    del whisper_model
    torch.cuda.empty_cache()
    return whisper_results, info.language


def transcribe_batched(
    audio_file: str,
    language: str,
    batch_size: int,
    model_name: str,
    compute_dtype: str,
    suppress_numerals: bool,
    device: str,
):
    import whisperx

    # Faster Whisper batched
    whisper_model = whisperx.load_model(
        model_name,
        device,
        compute_type=compute_dtype,
        asr_options={"suppress_numerals": suppress_numerals},
    )
    audio = whisperx.load_audio(audio_file)
    result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)
    del whisper_model
    torch.cuda.empty_cache()
    return result["segments"], result["language"], audio
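A hedged usage sketch for the helpers above: the model name, compute dtype, batch size, and audio path are placeholder choices, and the import assumes the caller runs from inside whisper_diarization_main.

# Sketch only: call the batched helper with placeholder arguments.
from transcription_helpers import transcribe_batched

segments, language, audio = transcribe_batched(
    audio_file="input_audio.wav",  # placeholder path
    language="en",                 # or None to let Whisper detect the language
    batch_size=8,
    model_name="medium.en",
    compute_dtype="float16",
    suppress_numerals=False,
    device="cuda",
)
print(f"detected language: {language}")
for seg in segments[:3]:
    print(seg["start"], seg["end"], seg["text"])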