Add progress listener to none/VAD
Browse filesNote that we don't handle progress of parallel transcription yet.
- app.py +46 -16
- src/hooks/whisperProgressHook.py +119 -0
- src/vad.py +8 -4
- src/whisperContainer.py +17 -4
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from datetime import datetime
|
2 |
import math
|
3 |
-
from typing import Iterator
|
4 |
import argparse
|
5 |
|
6 |
from io import StringIO
|
@@ -12,6 +12,7 @@ import numpy as np
|
|
12 |
|
13 |
import torch
|
14 |
from src.config import ApplicationConfig
|
|
|
15 |
from src.modelCache import ModelCache
|
16 |
from src.source import get_audio_source_collection
|
17 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
@@ -87,14 +88,17 @@ class WhisperTranscriber:
|
|
87 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
88 |
|
89 |
# Entry function for the simple tab
|
90 |
-
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow
|
91 |
-
|
|
|
|
|
92 |
|
93 |
# Entry function for the full tab
|
94 |
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
95 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
96 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
97 |
-
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float
|
|
|
98 |
|
99 |
# Handle temperature_increment_on_fallback
|
100 |
if temperature_increment_on_fallback is not None:
|
@@ -105,9 +109,11 @@ class WhisperTranscriber:
|
|
105 |
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
106 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
107 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
108 |
-
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold
|
|
|
109 |
|
110 |
-
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
|
|
111 |
try:
|
112 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
113 |
|
@@ -140,7 +146,7 @@ class WhisperTranscriber:
|
|
140 |
print("Transcribing ", source.source_path)
|
141 |
|
142 |
# Transcribe
|
143 |
-
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions)
|
144 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
145 |
|
146 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
|
@@ -202,7 +208,8 @@ class WhisperTranscriber:
|
|
202 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
203 |
|
204 |
def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
|
205 |
-
vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
|
|
|
206 |
|
207 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
208 |
|
@@ -212,25 +219,28 @@ class WhisperTranscriber:
|
|
212 |
# Callable for processing an audio file
|
213 |
whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
|
214 |
|
|
|
|
|
|
|
215 |
# The results
|
216 |
if (vad == 'silero-vad'):
|
217 |
# Silero VAD where non-speech gaps are transcribed
|
218 |
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
219 |
-
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
|
220 |
elif (vad == 'silero-vad-skip-gaps'):
|
221 |
# Silero VAD where non-speech gaps are simply ignored
|
222 |
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
223 |
-
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
|
224 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
225 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
226 |
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
227 |
-
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
|
228 |
elif (vad == 'periodic-vad'):
|
229 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
230 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
231 |
periodic_vad = VadPeriodicTranscription()
|
232 |
period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
|
233 |
-
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
|
234 |
|
235 |
else:
|
236 |
if (self._has_parallel_devices()):
|
@@ -238,18 +248,38 @@ class WhisperTranscriber:
|
|
238 |
periodic_vad = VadPeriodicTranscription()
|
239 |
period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
|
240 |
|
241 |
-
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
|
242 |
else:
|
243 |
# Default VAD
|
244 |
-
result = whisperCallable.invoke(audio_path, 0, None, None)
|
245 |
|
246 |
return result
|
247 |
|
248 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
if (not self._has_parallel_devices()):
|
250 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
251 |
-
return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
|
252 |
|
|
|
253 |
gpu_devices = self.parallel_device_list
|
254 |
|
255 |
if (gpu_devices is None or len(gpu_devices) == 0):
|
|
|
1 |
from datetime import datetime
|
2 |
import math
|
3 |
+
from typing import Iterator, Union
|
4 |
import argparse
|
5 |
|
6 |
from io import StringIO
|
|
|
12 |
|
13 |
import torch
|
14 |
from src.config import ApplicationConfig
|
15 |
+
from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
|
16 |
from src.modelCache import ModelCache
|
17 |
from src.source import get_audio_source_collection
|
18 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
|
|
88 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
89 |
|
90 |
# Entry function for the simple tab
|
91 |
+
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
92 |
+
progress=gr.Progress()):
|
93 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
94 |
+
progress=progress)
|
95 |
|
96 |
# Entry function for the full tab
|
97 |
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
98 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
99 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
100 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
101 |
+
progress=gr.Progress()):
|
102 |
|
103 |
# Handle temperature_increment_on_fallback
|
104 |
if temperature_increment_on_fallback is not None:
|
|
|
109 |
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
110 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
111 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
112 |
+
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
113 |
+
progress=progress)
|
114 |
|
115 |
+
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
116 |
+
progress: gr.Progress = None, **decodeOptions: dict):
|
117 |
try:
|
118 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
119 |
|
|
|
146 |
print("Transcribing ", source.source_path)
|
147 |
|
148 |
# Transcribe
|
149 |
+
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, progress, **decodeOptions)
|
150 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
151 |
|
152 |
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
|
|
|
208 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
209 |
|
210 |
def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
|
211 |
+
vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
|
212 |
+
progress: gr.Progress = None, **decodeOptions: dict):
|
213 |
|
214 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
215 |
|
|
|
219 |
# Callable for processing an audio file
|
220 |
whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
|
221 |
|
222 |
+
# A listener that will report progress to Gradio
|
223 |
+
progressListener = self._create_progress_listener(progress)
|
224 |
+
|
225 |
# The results
|
226 |
if (vad == 'silero-vad'):
|
227 |
# Silero VAD where non-speech gaps are transcribed
|
228 |
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
229 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
|
230 |
elif (vad == 'silero-vad-skip-gaps'):
|
231 |
# Silero VAD where non-speech gaps are simply ignored
|
232 |
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
233 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
|
234 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
235 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
236 |
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
237 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
|
238 |
elif (vad == 'periodic-vad'):
|
239 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
240 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
241 |
periodic_vad = VadPeriodicTranscription()
|
242 |
period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
|
243 |
+
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
244 |
|
245 |
else:
|
246 |
if (self._has_parallel_devices()):
|
|
|
248 |
periodic_vad = VadPeriodicTranscription()
|
249 |
period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
|
250 |
|
251 |
+
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
252 |
else:
|
253 |
# Default VAD
|
254 |
+
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
255 |
|
256 |
return result
|
257 |
|
258 |
+
def _create_progress_listener(self, progress: gr.Progress):
|
259 |
+
if (progress is None):
|
260 |
+
# Dummy progress listener
|
261 |
+
return ProgressListener()
|
262 |
+
|
263 |
+
class ForwardingProgressListener(ProgressListener):
|
264 |
+
def __init__(self, progress: gr.Progress):
|
265 |
+
self.progress = progress
|
266 |
+
|
267 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
268 |
+
# From 0 to 1
|
269 |
+
self.progress(current / total)
|
270 |
+
|
271 |
+
def on_finished(self):
|
272 |
+
self.progress(1)
|
273 |
+
|
274 |
+
return ForwardingProgressListener(progress)
|
275 |
+
|
276 |
+
def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
|
277 |
+
progressListener: ProgressListener = None):
|
278 |
if (not self._has_parallel_devices()):
|
279 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
280 |
+
return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
|
281 |
|
282 |
+
# TODO: Handle progress listener
|
283 |
gpu_devices = self.parallel_device_list
|
284 |
|
285 |
if (gpu_devices is None or len(gpu_devices) == 0):
|
src/hooks/whisperProgressHook.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import threading
|
3 |
+
from typing import List, Union
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
class ProgressListener:
|
7 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
8 |
+
self.total = total
|
9 |
+
|
10 |
+
def on_finished(self):
|
11 |
+
pass
|
12 |
+
|
13 |
+
class ProgressListenerHandle:
|
14 |
+
def __init__(self, listener: ProgressListener):
|
15 |
+
self.listener = listener
|
16 |
+
|
17 |
+
def __enter__(self):
|
18 |
+
register_thread_local_progress_listener(self.listener)
|
19 |
+
|
20 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
21 |
+
unregister_thread_local_progress_listener(self.listener)
|
22 |
+
|
23 |
+
if exc_type is None:
|
24 |
+
self.listener.on_finished()
|
25 |
+
|
26 |
+
class SubTaskProgressListener(ProgressListener):
|
27 |
+
"""
|
28 |
+
A sub task listener that reports the progress of a sub task to a base task listener
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
base_task_listener : ProgressListener
|
33 |
+
The base progress listener to accumulate overall progress in.
|
34 |
+
base_task_total : float
|
35 |
+
The maximum total progress that will be reported to the base progress listener.
|
36 |
+
sub_task_start : float
|
37 |
+
The starting progress of a sub task, in respect to the base progress listener.
|
38 |
+
sub_task_total : float
|
39 |
+
The total amount of progress a sub task will report to the base progress listener.
|
40 |
+
"""
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
base_task_listener: ProgressListener,
|
44 |
+
base_task_total: float,
|
45 |
+
sub_task_start: float,
|
46 |
+
sub_task_total: float,
|
47 |
+
):
|
48 |
+
self.base_task_listener = base_task_listener
|
49 |
+
self.base_task_total = base_task_total
|
50 |
+
self.sub_task_start = sub_task_start
|
51 |
+
self.sub_task_total = sub_task_total
|
52 |
+
|
53 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
54 |
+
sub_task_progress_frac = current / total
|
55 |
+
sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
|
56 |
+
self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
|
57 |
+
|
58 |
+
def on_finished(self):
|
59 |
+
self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
|
60 |
+
|
61 |
+
class _CustomProgressBar(tqdm.tqdm):
|
62 |
+
def __init__(self, *args, **kwargs):
|
63 |
+
super().__init__(*args, **kwargs)
|
64 |
+
self._current = self.n # Set the initial value
|
65 |
+
|
66 |
+
def update(self, n):
|
67 |
+
super().update(n)
|
68 |
+
# Because the progress bar might be disabled, we need to manually update the progress
|
69 |
+
self._current += n
|
70 |
+
|
71 |
+
# Inform listeners
|
72 |
+
listeners = _get_thread_local_listeners()
|
73 |
+
|
74 |
+
for listener in listeners:
|
75 |
+
listener.on_progress(self._current, self.total)
|
76 |
+
|
77 |
+
_thread_local = threading.local()
|
78 |
+
|
79 |
+
def _get_thread_local_listeners():
|
80 |
+
if not hasattr(_thread_local, 'listeners'):
|
81 |
+
_thread_local.listeners = []
|
82 |
+
return _thread_local.listeners
|
83 |
+
|
84 |
+
_hooked = False
|
85 |
+
|
86 |
+
def init_progress_hook():
|
87 |
+
global _hooked
|
88 |
+
|
89 |
+
if _hooked:
|
90 |
+
return
|
91 |
+
|
92 |
+
# Inject into tqdm.tqdm of Whisper, so we can see progress
|
93 |
+
import whisper.transcribe
|
94 |
+
transcribe_module = sys.modules['whisper.transcribe']
|
95 |
+
transcribe_module.tqdm.tqdm = _CustomProgressBar
|
96 |
+
_hooked = True
|
97 |
+
|
98 |
+
def register_thread_local_progress_listener(progress_listener: ProgressListener):
|
99 |
+
# This is a workaround for the fact that the progress bar is not exposed in the API
|
100 |
+
init_progress_hook()
|
101 |
+
|
102 |
+
listeners = _get_thread_local_listeners()
|
103 |
+
listeners.append(progress_listener)
|
104 |
+
|
105 |
+
def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
|
106 |
+
listeners = _get_thread_local_listeners()
|
107 |
+
|
108 |
+
if progress_listener in listeners:
|
109 |
+
listeners.remove(progress_listener)
|
110 |
+
|
111 |
+
def create_progress_listener_handle(progress_listener: ProgressListener):
|
112 |
+
return ProgressListenerHandle(progress_listener)
|
113 |
+
|
114 |
+
if __name__ == '__main__':
|
115 |
+
with create_progress_listener_handle(ProgressListener()) as listener:
|
116 |
+
# Call model.transcribe here
|
117 |
+
pass
|
118 |
+
|
119 |
+
print("Done")
|
src/vad.py
CHANGED
@@ -5,6 +5,7 @@ import time
|
|
5 |
from typing import Any, Deque, Iterator, List, Dict
|
6 |
|
7 |
from pprint import pprint
|
|
|
8 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
9 |
|
10 |
from src.segments import merge_timestamps
|
@@ -135,7 +136,8 @@ class AbstractTranscription(ABC):
|
|
135 |
pprint(merged)
|
136 |
return merged
|
137 |
|
138 |
-
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig
|
|
|
139 |
"""
|
140 |
Transcribe the given audo file.
|
141 |
|
@@ -184,7 +186,7 @@ class AbstractTranscription(ABC):
|
|
184 |
segment_duration = segment_end - segment_start
|
185 |
|
186 |
if segment_duration < MIN_SEGMENT_DURATION:
|
187 |
-
continue
|
188 |
|
189 |
# Audio to run on Whisper
|
190 |
segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
|
@@ -196,7 +198,9 @@ class AbstractTranscription(ABC):
|
|
196 |
|
197 |
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
198 |
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
199 |
-
|
|
|
|
|
200 |
|
201 |
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
202 |
|
@@ -226,7 +230,7 @@ class AbstractTranscription(ABC):
|
|
226 |
result['language'] = detected_language
|
227 |
|
228 |
return result
|
229 |
-
|
230 |
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
231 |
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
232 |
# Add segments to the current prompt window (unless it is a speech gap)
|
|
|
5 |
from typing import Any, Deque, Iterator, List, Dict
|
6 |
|
7 |
from pprint import pprint
|
8 |
+
from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
|
9 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
10 |
|
11 |
from src.segments import merge_timestamps
|
|
|
136 |
pprint(merged)
|
137 |
return merged
|
138 |
|
139 |
+
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
|
140 |
+
progressListener: ProgressListener = None):
|
141 |
"""
|
142 |
Transcribe the given audo file.
|
143 |
|
|
|
186 |
segment_duration = segment_end - segment_start
|
187 |
|
188 |
if segment_duration < MIN_SEGMENT_DURATION:
|
189 |
+
continue
|
190 |
|
191 |
# Audio to run on Whisper
|
192 |
segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
|
|
|
198 |
|
199 |
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
200 |
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
201 |
+
|
202 |
+
scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=max_audio_duration, sub_task_start=segment_start, sub_task_total=segment_duration)
|
203 |
+
segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
|
204 |
|
205 |
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
206 |
|
|
|
230 |
result['language'] = detected_language
|
231 |
|
232 |
return result
|
233 |
+
|
234 |
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
235 |
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
236 |
# Add segments to the current prompt window (unless it is a speech gap)
|
src/whisperContainer.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1 |
# External programs
|
2 |
import os
|
|
|
3 |
from typing import List
|
|
|
4 |
import whisper
|
|
|
|
|
5 |
from src.config import ModelConfig
|
|
|
6 |
|
7 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
8 |
|
@@ -116,7 +121,7 @@ class WhisperCallback:
|
|
116 |
self.initial_prompt = initial_prompt
|
117 |
self.decodeOptions = decodeOptions
|
118 |
|
119 |
-
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
|
120 |
"""
|
121 |
Peform the transcription of the given audio file or data.
|
122 |
|
@@ -139,10 +144,18 @@ class WhisperCallback:
|
|
139 |
"""
|
140 |
model = self.model_container.get_model()
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
return model.transcribe(audio, \
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
146 |
|
147 |
def _concat_prompt(self, prompt1, prompt2):
|
148 |
if (prompt1 is None):
|
|
|
1 |
# External programs
|
2 |
import os
|
3 |
+
import sys
|
4 |
from typing import List
|
5 |
+
|
6 |
import whisper
|
7 |
+
from whisper import Whisper
|
8 |
+
|
9 |
from src.config import ModelConfig
|
10 |
+
from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
|
11 |
|
12 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
13 |
|
|
|
121 |
self.initial_prompt = initial_prompt
|
122 |
self.decodeOptions = decodeOptions
|
123 |
|
124 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
125 |
"""
|
126 |
Peform the transcription of the given audio file or data.
|
127 |
|
|
|
144 |
"""
|
145 |
model = self.model_container.get_model()
|
146 |
|
147 |
+
if progress_listener is not None:
|
148 |
+
with create_progress_listener_handle(progress_listener):
|
149 |
+
return self._transcribe(model, audio, segment_index, prompt, detected_language)
|
150 |
+
else:
|
151 |
+
return self._transcribe(model, audio, segment_index, prompt, detected_language)
|
152 |
+
|
153 |
+
def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
|
154 |
return model.transcribe(audio, \
|
155 |
+
language=self.language if self.language else detected_language, task=self.task, \
|
156 |
+
initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
|
157 |
+
**self.decodeOptions
|
158 |
+
)
|
159 |
|
160 |
def _concat_prompt(self, prompt1, prompt2):
|
161 |
if (prompt1 is None):
|