Add support for parallel execution on multiple GPUs
Browse files- app.py +37 -14
- cli.py +2 -0
- src/vad.py +42 -24
- src/vadParallel.py +81 -0
- src/whisperContainer.py +91 -0
app.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
from typing import Iterator
|
|
|
2 |
|
3 |
from io import StringIO
|
4 |
import os
|
5 |
import pathlib
|
6 |
import tempfile
|
|
|
|
|
|
|
7 |
|
8 |
# External programs
|
9 |
import whisper
|
@@ -14,7 +18,7 @@ import gradio as gr
|
|
14 |
|
15 |
from src.download import ExceededMaximumDuration, download_url
|
16 |
from src.utils import slugify, write_srt, write_vtt
|
17 |
-
from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
18 |
|
19 |
# Limitations (set to -1 to disable)
|
20 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
@@ -48,6 +52,7 @@ LANGUAGES = [
|
|
48 |
class WhisperTranscriber:
|
49 |
def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
|
50 |
self.model_cache = dict()
|
|
|
51 |
|
52 |
self.vad_model = None
|
53 |
self.inputAudioMaxDuration = inputAudioMaxDuration
|
@@ -64,7 +69,7 @@ class WhisperTranscriber:
|
|
64 |
model = self.model_cache.get(selectedModel, None)
|
65 |
|
66 |
if not model:
|
67 |
-
model =
|
68 |
self.model_cache[selectedModel] = model
|
69 |
|
70 |
# Execute whisper
|
@@ -87,7 +92,7 @@ class WhisperTranscriber:
|
|
87 |
except ExceededMaximumDuration as e:
|
88 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
89 |
|
90 |
-
def transcribe_file(self, model:
|
91 |
vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
|
92 |
|
93 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
@@ -96,35 +101,42 @@ class WhisperTranscriber:
|
|
96 |
task = decodeOptions.pop('task')
|
97 |
|
98 |
# Callable for processing an audio file
|
99 |
-
whisperCallable =
|
100 |
-
language=language if language else detected_language, task=task, \
|
101 |
-
initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
|
102 |
-
**decodeOptions)
|
103 |
|
104 |
# The results
|
105 |
if (vad == 'silero-vad'):
|
106 |
# Silero VAD where non-speech gaps are transcribed
|
107 |
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
108 |
-
result = self.
|
109 |
elif (vad == 'silero-vad-skip-gaps'):
|
110 |
# Silero VAD where non-speech gaps are simply ignored
|
111 |
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
112 |
-
result = self.
|
113 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
114 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
115 |
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
116 |
-
result = self.
|
117 |
elif (vad == 'periodic-vad'):
|
118 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
119 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
120 |
periodic_vad = VadPeriodicTranscription()
|
121 |
-
|
|
|
|
|
122 |
else:
|
123 |
# Default VAD
|
124 |
result = whisperCallable(audio_path, 0, None, None)
|
125 |
|
126 |
return result
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
def _concat_prompt(self, prompt1, prompt2):
|
129 |
if (prompt1 is None):
|
130 |
return prompt2
|
@@ -218,9 +230,12 @@ class WhisperTranscriber:
|
|
218 |
return file.name
|
219 |
|
220 |
|
221 |
-
def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
|
222 |
ui = WhisperTranscriber(inputAudioMaxDuration)
|
223 |
|
|
|
|
|
|
|
224 |
ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
225 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
226 |
ui_description += " as well as speech translation and language identification. "
|
@@ -250,7 +265,15 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
|
|
250 |
gr.Text(label="Segments")
|
251 |
])
|
252 |
|
253 |
-
demo.launch(share=share, server_name=server_name)
|
254 |
|
255 |
if __name__ == '__main__':
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Iterator
|
2 |
+
import argparse
|
3 |
|
4 |
from io import StringIO
|
5 |
import os
|
6 |
import pathlib
|
7 |
import tempfile
|
8 |
+
from src.vadParallel import ParallelTranscription
|
9 |
+
|
10 |
+
from src.whisperContainer import WhisperContainer
|
11 |
|
12 |
# External programs
|
13 |
import whisper
|
|
|
18 |
|
19 |
from src.download import ExceededMaximumDuration, download_url
|
20 |
from src.utils import slugify, write_srt, write_vtt
|
21 |
+
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
22 |
|
23 |
# Limitations (set to -1 to disable)
|
24 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
|
|
52 |
class WhisperTranscriber:
|
53 |
def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
|
54 |
self.model_cache = dict()
|
55 |
+
self.parallel_device_list = None
|
56 |
|
57 |
self.vad_model = None
|
58 |
self.inputAudioMaxDuration = inputAudioMaxDuration
|
|
|
69 |
model = self.model_cache.get(selectedModel, None)
|
70 |
|
71 |
if not model:
|
72 |
+
model = WhisperContainer(selectedModel)
|
73 |
self.model_cache[selectedModel] = model
|
74 |
|
75 |
# Execute whisper
|
|
|
92 |
except ExceededMaximumDuration as e:
|
93 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
94 |
|
95 |
+
def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
|
96 |
vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
|
97 |
|
98 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
|
|
101 |
task = decodeOptions.pop('task')
|
102 |
|
103 |
# Callable for processing an audio file
|
104 |
+
whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
|
|
|
|
|
|
|
105 |
|
106 |
# The results
|
107 |
if (vad == 'silero-vad'):
|
108 |
# Silero VAD where non-speech gaps are transcribed
|
109 |
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
110 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
|
111 |
elif (vad == 'silero-vad-skip-gaps'):
|
112 |
# Silero VAD where non-speech gaps are simply ignored
|
113 |
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
114 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
|
115 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
116 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
117 |
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
118 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
|
119 |
elif (vad == 'periodic-vad'):
|
120 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
121 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
122 |
periodic_vad = VadPeriodicTranscription()
|
123 |
+
period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
|
124 |
+
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
|
125 |
+
|
126 |
else:
|
127 |
# Default VAD
|
128 |
result = whisperCallable(audio_path, 0, None, None)
|
129 |
|
130 |
return result
|
131 |
|
132 |
+
def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
|
133 |
+
if (self.parallel_device_list is None or len(self.parallel_device_list) == 0):
|
134 |
+
# No parallel devices, so just run the VAD and Whisper in sequence
|
135 |
+
return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
|
136 |
+
|
137 |
+
parallell_vad = ParallelTranscription()
|
138 |
+
return parallell_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable, config=vadConfig, devices=self.parallel_device_list)
|
139 |
+
|
140 |
def _concat_prompt(self, prompt1, prompt2):
|
141 |
if (prompt1 is None):
|
142 |
return prompt2
|
|
|
230 |
return file.name
|
231 |
|
232 |
|
233 |
+
def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, server_port: int = 7860, vad_parallel_devices: str = None):
|
234 |
ui = WhisperTranscriber(inputAudioMaxDuration)
|
235 |
|
236 |
+
# Specify a list of devices to use for parallel processing
|
237 |
+
ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
|
238 |
+
|
239 |
ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
240 |
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
241 |
ui_description += " as well as speech translation and language identification. "
|
|
|
265 |
gr.Text(label="Segments")
|
266 |
])
|
267 |
|
268 |
+
demo.launch(share=share, server_name=server_name, server_port=server_port)
|
269 |
|
270 |
if __name__ == '__main__':
|
271 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
272 |
+
parser.add_argument("--inputAudioMaxDuration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
|
273 |
+
parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
|
274 |
+
parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
|
275 |
+
parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
|
276 |
+
parser.add_argument("--vad_parallel_devices", type=str, default="0,1", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
|
277 |
+
|
278 |
+
args = parser.parse_args().__dict__
|
279 |
+
create_ui(**args)
|
cli.py
CHANGED
@@ -31,6 +31,7 @@ def cli():
|
|
31 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
32 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
33 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
|
|
34 |
|
35 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
36 |
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
@@ -74,6 +75,7 @@ def cli():
|
|
74 |
|
75 |
model = whisper.load_model(model_name, device=device, download_root=model_dir)
|
76 |
transcriber = WhisperTranscriber(deleteUploadedFiles=False)
|
|
|
77 |
|
78 |
for audio_path in args.pop("audio"):
|
79 |
sources = []
|
|
|
31 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
32 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
33 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
34 |
+
parser.add_argument("--vad_parallel_devices", type=str, default="0", help="A commma delimited list of CUDA devices to use for paralell processing. If None, disable parallel processing.")
|
35 |
|
36 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
37 |
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
|
|
75 |
|
76 |
model = whisper.load_model(model_name, device=device, download_root=model_dir)
|
77 |
transcriber = WhisperTranscriber(deleteUploadedFiles=False)
|
78 |
+
transcriber.parallel_device_list = args.pop("vad_parallel_devices")
|
79 |
|
80 |
for audio_path in args.pop("audio"):
|
81 |
sources = []
|
src/vad.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Deque, Iterator, List, Dict
|
|
6 |
from pprint import pprint
|
7 |
|
8 |
from src.segments import merge_timestamps
|
|
|
9 |
|
10 |
# Workaround for https://github.com/tensorflow/tensorflow/issues/48797
|
11 |
try:
|
@@ -51,19 +52,20 @@ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
|
|
51 |
class TranscriptionConfig(ABC):
|
52 |
def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
53 |
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
54 |
-
max_merge_size: float = None, max_prompt_window: float = None):
|
55 |
self.non_speech_strategy = non_speech_strategy
|
56 |
self.segment_padding_left = segment_padding_left
|
57 |
self.segment_padding_right = segment_padding_right
|
58 |
self.max_silent_period = max_silent_period
|
59 |
self.max_merge_size = max_merge_size
|
60 |
self.max_prompt_window = max_prompt_window
|
|
|
61 |
|
62 |
class PeriodicTranscriptionConfig(TranscriptionConfig):
|
63 |
def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
64 |
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
65 |
-
max_merge_size: float = None, max_prompt_window: float = None):
|
66 |
-
super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
|
67 |
self.periodic_duration = periodic_duration
|
68 |
|
69 |
class AbstractTranscription(ABC):
|
@@ -91,37 +93,26 @@ class AbstractTranscription(ABC):
|
|
91 |
"""
|
92 |
return
|
93 |
|
94 |
-
def
|
95 |
"""
|
96 |
-
|
|
|
97 |
|
98 |
Parameters
|
99 |
----------
|
100 |
audio: str
|
101 |
-
The audio file.
|
102 |
-
|
103 |
-
|
104 |
-
The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
|
105 |
-
the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
|
106 |
|
107 |
Returns
|
108 |
-------
|
109 |
A list of start and end timestamps, in fractional seconds.
|
110 |
"""
|
111 |
-
|
112 |
-
# get speech timestamps from full audio file
|
113 |
seconds_timestamps = self.get_transcribe_timestamps(audio, config)
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
|
119 |
-
|
120 |
-
# A deque of transcribed segments that is passed to the next segment as a prompt
|
121 |
-
prompt_window = deque()
|
122 |
-
|
123 |
-
print("Timestamps:")
|
124 |
-
pprint(merged)
|
125 |
|
126 |
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
127 |
max_audio_duration = get_audio_duration(audio)
|
@@ -138,6 +129,32 @@ class AbstractTranscription(ABC):
|
|
138 |
|
139 |
print("Transcribing non-speech:")
|
140 |
pprint(merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
result = {
|
143 |
'text': "",
|
@@ -147,7 +164,7 @@ class AbstractTranscription(ABC):
|
|
147 |
languageCounter = Counter()
|
148 |
detected_language = None
|
149 |
|
150 |
-
segment_index =
|
151 |
|
152 |
# For each time segment, run whisper
|
153 |
for segment in merged:
|
@@ -172,7 +189,7 @@ class AbstractTranscription(ABC):
|
|
172 |
|
173 |
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
174 |
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
175 |
-
segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
|
176 |
|
177 |
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
178 |
|
@@ -373,6 +390,7 @@ class AbstractTranscription(ABC):
|
|
373 |
})
|
374 |
return result
|
375 |
|
|
|
376 |
class VadSileroTranscription(AbstractTranscription):
|
377 |
def __init__(self, sampling_rate: int = 16000):
|
378 |
super().__init__(sampling_rate=sampling_rate)
|
|
|
6 |
from pprint import pprint
|
7 |
|
8 |
from src.segments import merge_timestamps
|
9 |
+
from src.whisperContainer import WhisperCallback
|
10 |
|
11 |
# Workaround for https://github.com/tensorflow/tensorflow/issues/48797
|
12 |
try:
|
|
|
52 |
class TranscriptionConfig(ABC):
|
53 |
def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
54 |
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
55 |
+
max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
|
56 |
self.non_speech_strategy = non_speech_strategy
|
57 |
self.segment_padding_left = segment_padding_left
|
58 |
self.segment_padding_right = segment_padding_right
|
59 |
self.max_silent_period = max_silent_period
|
60 |
self.max_merge_size = max_merge_size
|
61 |
self.max_prompt_window = max_prompt_window
|
62 |
+
self.initial_segment_index = initial_segment_index
|
63 |
|
64 |
class PeriodicTranscriptionConfig(TranscriptionConfig):
|
65 |
def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
66 |
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
67 |
+
max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
|
68 |
+
super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
|
69 |
self.periodic_duration = periodic_duration
|
70 |
|
71 |
class AbstractTranscription(ABC):
|
|
|
93 |
"""
|
94 |
return
|
95 |
|
96 |
+
def get_merged_timestamps(self, audio: str, config: TranscriptionConfig):
|
97 |
"""
|
98 |
+
Get the start and end timestamps of the sections that should be transcribed by this VAD method,
|
99 |
+
after merging the segments using the specified configuration.
|
100 |
|
101 |
Parameters
|
102 |
----------
|
103 |
audio: str
|
104 |
+
The audio file.
|
105 |
+
config: TranscriptionConfig
|
106 |
+
The transcription configuration.
|
|
|
|
|
107 |
|
108 |
Returns
|
109 |
-------
|
110 |
A list of start and end timestamps, in fractional seconds.
|
111 |
"""
|
|
|
|
|
112 |
seconds_timestamps = self.get_transcribe_timestamps(audio, config)
|
113 |
|
114 |
+
merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size,
|
115 |
+
config.segment_padding_left, config.segment_padding_right)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
118 |
max_audio_duration = get_audio_duration(audio)
|
|
|
129 |
|
130 |
print("Transcribing non-speech:")
|
131 |
pprint(merged)
|
132 |
+
return merged
|
133 |
+
|
134 |
+
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
|
135 |
+
"""
|
136 |
+
Transcribe the given audo file.
|
137 |
+
|
138 |
+
Parameters
|
139 |
+
----------
|
140 |
+
audio: str
|
141 |
+
The audio file.
|
142 |
+
whisperCallable: WhisperCallback
|
143 |
+
A callback object to call to transcribe each segment.
|
144 |
+
|
145 |
+
Returns
|
146 |
+
-------
|
147 |
+
A list of start and end timestamps, in fractional seconds.
|
148 |
+
"""
|
149 |
+
|
150 |
+
# Get speech timestamps from full audio file
|
151 |
+
merged = self.get_merged_timestamps(audio, config)
|
152 |
+
|
153 |
+
# A deque of transcribed segments that is passed to the next segment as a prompt
|
154 |
+
prompt_window = deque()
|
155 |
+
|
156 |
+
print("Processing timestamps:")
|
157 |
+
pprint(merged)
|
158 |
|
159 |
result = {
|
160 |
'text': "",
|
|
|
164 |
languageCounter = Counter()
|
165 |
detected_language = None
|
166 |
|
167 |
+
segment_index = config.initial_segment_index
|
168 |
|
169 |
# For each time segment, run whisper
|
170 |
for segment in merged:
|
|
|
189 |
|
190 |
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
191 |
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
192 |
+
segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
|
193 |
|
194 |
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
195 |
|
|
|
390 |
})
|
391 |
return result
|
392 |
|
393 |
+
|
394 |
class VadSileroTranscription(AbstractTranscription):
|
395 |
def __init__(self, sampling_rate: int = 16000):
|
396 |
super().__init__(sampling_rate=sampling_rate)
|
src/vadParallel.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.vad import AbstractTranscription, TranscriptionConfig
|
2 |
+
from src.whisperContainer import WhisperCallback
|
3 |
+
|
4 |
+
from multiprocessing import Pool
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
import os
|
8 |
+
|
9 |
+
class ParallelTranscriptionConfig(TranscriptionConfig):
|
10 |
+
def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
|
11 |
+
super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
|
12 |
+
self.device_id = device_id
|
13 |
+
self.override_timestamps = override_timestamps
|
14 |
+
|
15 |
+
class ParallelTranscription(AbstractTranscription):
|
16 |
+
def __init__(self, sampling_rate: int = 16000):
|
17 |
+
super().__init__(sampling_rate=sampling_rate)
|
18 |
+
|
19 |
+
|
20 |
+
def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str]):
|
21 |
+
# First, get the timestamps for the original audio
|
22 |
+
merged = transcription.get_merged_timestamps(audio, config)
|
23 |
+
|
24 |
+
# Split into a list for each device
|
25 |
+
merged_split = self._chunks(merged, len(merged) // len(devices))
|
26 |
+
|
27 |
+
# Parameters that will be passed to the transcribe function
|
28 |
+
parameters = []
|
29 |
+
segment_index = config.initial_segment_index
|
30 |
+
|
31 |
+
for i in range(len(devices)):
|
32 |
+
device_segment_list = merged_split[i]
|
33 |
+
|
34 |
+
# Create a new config with the given device ID
|
35 |
+
device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
|
36 |
+
segment_index += len(device_segment_list)
|
37 |
+
|
38 |
+
parameters.append([audio, whisperCallable, device_config]);
|
39 |
+
|
40 |
+
merged = {
|
41 |
+
'text': '',
|
42 |
+
'segments': [],
|
43 |
+
'language': None
|
44 |
+
}
|
45 |
+
|
46 |
+
with Pool(len(devices)) as p:
|
47 |
+
# Run the transcription in parallel
|
48 |
+
results = p.starmap(self.transcribe, parameters)
|
49 |
+
|
50 |
+
for result in results:
|
51 |
+
# Merge the results
|
52 |
+
if (result['text'] is not None):
|
53 |
+
merged['text'] += result['text']
|
54 |
+
if (result['segments'] is not None):
|
55 |
+
merged['segments'].extend(result['segments'])
|
56 |
+
if (result['language'] is not None):
|
57 |
+
merged['language'] = result['language']
|
58 |
+
|
59 |
+
return merged
|
60 |
+
|
61 |
+
def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
|
62 |
+
return []
|
63 |
+
|
64 |
+
def get_merged_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
|
65 |
+
# Override timestamps that will be processed
|
66 |
+
if (config.override_timestamps is not None):
|
67 |
+
print("Using override timestamps of size " + str(len(config.override_timestamps)))
|
68 |
+
return config.override_timestamps
|
69 |
+
return super().get_merged_timestamps(audio, config)
|
70 |
+
|
71 |
+
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
|
72 |
+
# Override device ID
|
73 |
+
if (config.device_id is not None):
|
74 |
+
print("Using device " + config.device_id)
|
75 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
76 |
+
return super().transcribe(audio, whisperCallable, config)
|
77 |
+
|
78 |
+
def _chunks(self, lst, n):
|
79 |
+
"""Yield successive n-sized chunks from lst."""
|
80 |
+
return [lst[i:i + n] for i in range(0, len(lst), n)]
|
81 |
+
|
src/whisperContainer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# External programs
|
2 |
+
import whisper
|
3 |
+
|
4 |
+
class WhisperContainer:
|
5 |
+
def __init__(self, model_name: str, device: str = None):
|
6 |
+
self.model_name = model_name
|
7 |
+
self.device = device
|
8 |
+
|
9 |
+
# Will be created on demand
|
10 |
+
self.model = None
|
11 |
+
|
12 |
+
def get_model(self):
|
13 |
+
if self.model is None:
|
14 |
+
print("Loading model " + self.model_name)
|
15 |
+
self.model = whisper.load_model(self.model_name, device=self.device)
|
16 |
+
return self.model
|
17 |
+
|
18 |
+
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
19 |
+
"""
|
20 |
+
Create a WhisperCallback object that can be used to transcript audio files.
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
language: str
|
25 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
26 |
+
task: str
|
27 |
+
The task - either translate or transcribe.
|
28 |
+
initial_prompt: str
|
29 |
+
The initial prompt to use for the transcription.
|
30 |
+
decodeOptions: dict
|
31 |
+
Additional options to pass to the decoder. Must be pickleable.
|
32 |
+
|
33 |
+
Returns
|
34 |
+
-------
|
35 |
+
A WhisperCallback object.
|
36 |
+
"""
|
37 |
+
return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
|
38 |
+
|
39 |
+
# This is required for multiprocessing
|
40 |
+
def __getstate__(self):
|
41 |
+
return { "model_name": self.model_name, "device": self.device }
|
42 |
+
|
43 |
+
def __setstate__(self, state):
|
44 |
+
self.model_name = state["model_name"]
|
45 |
+
self.device = state["device"]
|
46 |
+
self.model = None
|
47 |
+
|
48 |
+
|
49 |
+
class WhisperCallback:
|
50 |
+
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
51 |
+
self.model_container = model_container
|
52 |
+
self.language = language
|
53 |
+
self.task = task
|
54 |
+
self.initial_prompt = initial_prompt
|
55 |
+
self.decodeOptions = decodeOptions
|
56 |
+
|
57 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
|
58 |
+
"""
|
59 |
+
Peform the transcription of the given audio file or data.
|
60 |
+
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
64 |
+
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
|
65 |
+
segment_index: int
|
66 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
67 |
+
task: str
|
68 |
+
The task - either translate or transcribe.
|
69 |
+
prompt: str
|
70 |
+
The prompt to use for the transcription.
|
71 |
+
detected_language: str
|
72 |
+
The detected language of the audio file.
|
73 |
+
|
74 |
+
Returns
|
75 |
+
-------
|
76 |
+
The result of the Whisper call.
|
77 |
+
"""
|
78 |
+
model = self.model_container.get_model()
|
79 |
+
|
80 |
+
return model.transcribe(audio, \
|
81 |
+
language=self.language if self.language else detected_language, task=self.task, \
|
82 |
+
initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
|
83 |
+
**self.decodeOptions)
|
84 |
+
|
85 |
+
def _concat_prompt(self, prompt1, prompt2):
|
86 |
+
if (prompt1 is None):
|
87 |
+
return prompt2
|
88 |
+
elif (prompt2 is None):
|
89 |
+
return prompt1
|
90 |
+
else:
|
91 |
+
return prompt1 + " " + prompt2
|