aadnk committed
Commit 33a2c1e
1 Parent(s): 418fd6a

Add progress listener to none/VAD


Note that we don't handle progress of parallel transcription yet.
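For reference, a minimal sketch of how the new listener API is meant to be consumed (the listener subclass and the transcribe call site are hypothetical; the hook names come from src/hooks/whisperProgressHook.py below):

    from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle

    class PrintingProgressListener(ProgressListener):
        def on_progress(self, current, total):
            print(f"Progress: {current}/{total}")

        def on_finished(self):
            print("Finished")

    # While the handle is open, tqdm updates made by whisper.transcribe on this
    # thread are forwarded to on_progress; on_finished fires on a clean exit.
    with create_progress_listener_handle(PrintingProgressListener()):
        pass  # e.g. model.transcribe("audio.mp3") -- hypothetical call site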

Files changed (4)
  1. app.py +46 -16
  2. src/hooks/whisperProgressHook.py +119 -0
  3. src/vad.py +8 -4
  4. src/whisperContainer.py +17 -4
app.py CHANGED
@@ -1,6 +1,6 @@
 from datetime import datetime
 import math
-from typing import Iterator
+from typing import Iterator, Union
 import argparse
 
 from io import StringIO
@@ -12,6 +12,7 @@ import numpy as np
 
 import torch
 from src.config import ApplicationConfig
+from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
 from src.modelCache import ModelCache
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
@@ -87,14 +88,17 @@ class WhisperTranscriber:
         print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
 
     # Entry function for the simple tab
-    def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
-        return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
+    def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
+                                progress=gr.Progress()):
+        return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
+                                     progress=progress)
 
     # Entry function for the full tab
     def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
                               initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
                               condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
-                              compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
+                              compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
+                              progress=gr.Progress()):
 
         # Handle temperature_increment_on_fallback
         if temperature_increment_on_fallback is not None:
@@ -105,9 +109,11 @@ class WhisperTranscriber:
         return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
                                      initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
                                      condition_on_previous_text=condition_on_previous_text, fp16=fp16,
-                                     compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold)
+                                     compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
+                                     progress=progress)
 
-    def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions: dict):
+    def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
+                         progress: gr.Progress = None, **decodeOptions: dict):
         try:
             sources = self.__get_source(urlData, multipleFiles, microphoneData)
 
@@ -140,7 +146,7 @@ class WhisperTranscriber:
                 print("Transcribing ", source.source_path)
 
                 # Transcribe
-                result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions)
+                result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, progress, **decodeOptions)
                 filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
 
                 source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
@@ -202,7 +208,8 @@ class WhisperTranscriber:
             return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
 
     def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
-                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
+                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
+                        progress: gr.Progress = None, **decodeOptions: dict):
 
         initial_prompt = decodeOptions.pop('initial_prompt', None)
 
@@ -212,25 +219,28 @@ class WhisperTranscriber:
         # Callable for processing an audio file
         whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
 
+        # A listener that will report progress to Gradio
+        progressListener = self._create_progress_listener(progress)
+
         # The results
         if (vad == 'silero-vad'):
             # Silero VAD where non-speech gaps are transcribed
             process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
         elif (vad == 'silero-vad-skip-gaps'):
             # Silero VAD where non-speech gaps are simply ignored
             skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
         elif (vad == 'silero-vad-expand-into-gaps'):
             # Use Silero VAD where speech-segments are expanded into non-speech gaps
             expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
         elif (vad == 'periodic-vad'):
             # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
             # it may create a break in the middle of a sentence, causing some artifacts.
             periodic_vad = VadPeriodicTranscription()
             period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
-            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
+            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
 
         else:
             if (self._has_parallel_devices()):
@@ -238,18 +248,38 @@ class WhisperTranscriber:
                 periodic_vad = VadPeriodicTranscription()
                 period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
 
-                result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
+                result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
             else:
                 # Default VAD
-                result = whisperCallable.invoke(audio_path, 0, None, None)
+                result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
 
         return result
 
-    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
+    def _create_progress_listener(self, progress: gr.Progress):
+        if (progress is None):
+            # Dummy progress listener
+            return ProgressListener()
+
+        class ForwardingProgressListener(ProgressListener):
+            def __init__(self, progress: gr.Progress):
+                self.progress = progress
+
+            def on_progress(self, current: Union[int, float], total: Union[int, float]):
+                # From 0 to 1
+                self.progress(current / total)
+
+            def on_finished(self):
+                self.progress(1)
+
+        return ForwardingProgressListener(progress)
+
+    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
+                    progressListener: ProgressListener = None):
         if (not self._has_parallel_devices()):
            # No parallel devices, so just run the VAD and Whisper in sequence
-            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
+            return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
 
+        # TODO: Handle progress listener
         gpu_devices = self.parallel_device_list
 
         if (gpu_devices is None or len(gpu_devices) == 0):
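Note that ForwardingProgressListener works because a gr.Progress instance is callable with a fraction between 0 and 1. A standalone sketch of that Gradio behavior (long_task and the loop bounds are made up for illustration):

    import gradio as gr

    def long_task(progress=gr.Progress()):
        for i in range(10):
            progress(i / 10)  # same call shape as ForwardingProgressListener.on_progress
        return "done"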
src/hooks/whisperProgressHook.py ADDED
@@ -0,0 +1,119 @@
+import sys
+import threading
+from typing import List, Union
+import tqdm
+
+class ProgressListener:
+    def on_progress(self, current: Union[int, float], total: Union[int, float]):
+        self.total = total
+
+    def on_finished(self):
+        pass
+
+class ProgressListenerHandle:
+    def __init__(self, listener: ProgressListener):
+        self.listener = listener
+
+    def __enter__(self):
+        register_thread_local_progress_listener(self.listener)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        unregister_thread_local_progress_listener(self.listener)
+
+        if exc_type is None:
+            self.listener.on_finished()
+
+class SubTaskProgressListener(ProgressListener):
+    """
+    A sub task listener that reports the progress of a sub task to a base task listener
+
+    Parameters
+    ----------
+    base_task_listener : ProgressListener
+        The base progress listener to accumulate overall progress in.
+    base_task_total : float
+        The maximum total progress that will be reported to the base progress listener.
+    sub_task_start : float
+        The starting progress of a sub task, with respect to the base progress listener.
+    sub_task_total : float
+        The total amount of progress a sub task will report to the base progress listener.
+    """
+    def __init__(
+        self,
+        base_task_listener: ProgressListener,
+        base_task_total: float,
+        sub_task_start: float,
+        sub_task_total: float,
+    ):
+        self.base_task_listener = base_task_listener
+        self.base_task_total = base_task_total
+        self.sub_task_start = sub_task_start
+        self.sub_task_total = sub_task_total
+
+    def on_progress(self, current: Union[int, float], total: Union[int, float]):
+        sub_task_progress_frac = current / total
+        sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
+        self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
+
+    def on_finished(self):
+        self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
+
+class _CustomProgressBar(tqdm.tqdm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._current = self.n  # Set the initial value
+
+    def update(self, n):
+        super().update(n)
+        # Because the progress bar might be disabled, we need to manually update the progress
+        self._current += n
+
+        # Inform listeners
+        listeners = _get_thread_local_listeners()
+
+        for listener in listeners:
+            listener.on_progress(self._current, self.total)
+
+_thread_local = threading.local()
+
+def _get_thread_local_listeners():
+    if not hasattr(_thread_local, 'listeners'):
+        _thread_local.listeners = []
+    return _thread_local.listeners
+
+_hooked = False
+
+def init_progress_hook():
+    global _hooked
+
+    if _hooked:
+        return
+
+    # Inject into tqdm.tqdm of Whisper, so we can see progress
+    import whisper.transcribe
+    transcribe_module = sys.modules['whisper.transcribe']
+    transcribe_module.tqdm.tqdm = _CustomProgressBar
+    _hooked = True
+
+def register_thread_local_progress_listener(progress_listener: ProgressListener):
+    # This is a workaround for the fact that the progress bar is not exposed in the API
+    init_progress_hook()
+
+    listeners = _get_thread_local_listeners()
+    listeners.append(progress_listener)
+
+def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
+    listeners = _get_thread_local_listeners()
+
+    if progress_listener in listeners:
+        listeners.remove(progress_listener)
+
+def create_progress_listener_handle(progress_listener: ProgressListener):
+    return ProgressListenerHandle(progress_listener)
+
+if __name__ == '__main__':
+    with create_progress_listener_handle(ProgressListener()) as listener:
+        # Call model.transcribe here
+        pass
+
+    print("Done")
src/vad.py CHANGED
@@ -5,6 +5,7 @@ import time
 from typing import Any, Deque, Iterator, List, Dict
 
 from pprint import pprint
+from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 from src.segments import merge_timestamps
@@ -135,7 +136,8 @@ class AbstractTranscription(ABC):
             pprint(merged)
         return merged
 
-    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
+    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
+                   progressListener: ProgressListener = None):
         """
         Transcribe the given audio file.
 
@@ -184,7 +186,7 @@ class AbstractTranscription(ABC):
             segment_duration = segment_end - segment_start
 
             if segment_duration < MIN_SEGMENT_DURATION:
-                continue;
+                continue
 
             # Audio to run on Whisper
             segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
@@ -196,7 +198,9 @@ class AbstractTranscription(ABC):
 
             print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                   segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
-            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
+
+            scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=max_audio_duration, sub_task_start=segment_start, sub_task_total=segment_duration)
+            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
 
             adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
 
@@ -226,7 +230,7 @@ class AbstractTranscription(ABC):
         result['language'] = detected_language
 
         return result
-    
+
     def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
         if (config.max_prompt_window is not None and config.max_prompt_window > 0):
             # Add segments to the current prompt window (unless it is a speech gap)
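Tying the pieces together: for every speech segment, transcribe() wraps the caller's listener in a SubTaskProgressListener so the segment's local tqdm progress maps onto the whole file. A hedged usage sketch (vad_model, whisper_callable and config stand in for the objects app.py builds):

    class ConsoleListener(ProgressListener):
        def on_progress(self, current, total):
            print(f"{current / total:.0%}")

    result = vad_model.transcribe("audio.wav", whisper_callable, config,
                                  progressListener=ConsoleListener())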
src/whisperContainer.py CHANGED
@@ -1,8 +1,13 @@
 # External programs
 import os
+import sys
 from typing import List
+
 import whisper
+from whisper import Whisper
+
 from src.config import ModelConfig
+from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
@@ -116,7 +121,7 @@ class WhisperCallback:
         self.initial_prompt = initial_prompt
         self.decodeOptions = decodeOptions
 
-    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
+    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
         """
         Perform the transcription of the given audio file or data.
 
@@ -139,10 +144,18 @@ class WhisperCallback:
         """
         model = self.model_container.get_model()
 
+        if progress_listener is not None:
+            with create_progress_listener_handle(progress_listener):
+                return self._transcribe(model, audio, segment_index, prompt, detected_language)
+        else:
+            return self._transcribe(model, audio, segment_index, prompt, detected_language)
+
+    def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
         return model.transcribe(audio, \
-            language=self.language if self.language else detected_language, task=self.task, \
-            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions)
+            language=self.language if self.language else detected_language, task=self.task, \
+            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
+            **self.decodeOptions
+        )
 
     def _concat_prompt(self, prompt1, prompt2):
         if (prompt1 is None):
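A usage sketch for the new invoke signature (container and my_listener are assumed to exist; the create_callback arguments mirror the call in app.py's transcribe_file):

    callback = container.create_callback(language="en", task="transcribe", initial_prompt=None)
    result = callback.invoke("audio.wav", 0, None, None, progress_listener=my_listener)
    print(len(result["segments"]), "segments")

When progress_listener is None the handle is skipped entirely, so existing callers that never pass a listener keep the old behavior.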