Spaces:
Sleeping
Sleeping
Support progress for multiple devices
Browse files- app.py +2 -2
- src/vad.py +78 -66
- src/vadParallel.py +50 -8
app.py
CHANGED
@@ -279,7 +279,6 @@ class WhisperTranscriber:
|
|
279 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
280 |
return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
|
281 |
|
282 |
-
# TODO: Handle progress listener
|
283 |
gpu_devices = self.parallel_device_list
|
284 |
|
285 |
if (gpu_devices is None or len(gpu_devices) == 0):
|
@@ -297,7 +296,8 @@ class WhisperTranscriber:
|
|
297 |
parallel_vad = ParallelTranscription()
|
298 |
return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
|
299 |
config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
|
300 |
-
cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context
|
|
|
301 |
|
302 |
def _has_parallel_devices(self):
|
303 |
return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
|
|
|
279 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
280 |
return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
|
281 |
|
|
|
282 |
gpu_devices = self.parallel_device_list
|
283 |
|
284 |
if (gpu_devices is None or len(gpu_devices) == 0):
|
|
|
296 |
parallel_vad = ParallelTranscription()
|
297 |
return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
|
298 |
config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
|
299 |
+
cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
|
300 |
+
progress_listener=progressListener)
|
301 |
|
302 |
def _has_parallel_devices(self):
|
303 |
return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
|
src/vad.py
CHANGED
@@ -153,84 +153,96 @@ class AbstractTranscription(ABC):
|
|
153 |
A list of start and end timestamps, in fractional seconds.
|
154 |
"""
|
155 |
|
156 |
-
|
157 |
-
|
|
|
158 |
|
159 |
-
|
160 |
-
|
161 |
|
162 |
-
|
163 |
-
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
result = {
|
169 |
-
'text': "",
|
170 |
-
'segments': [],
|
171 |
-
'language': ""
|
172 |
-
}
|
173 |
-
languageCounter = Counter()
|
174 |
-
detected_language = None
|
175 |
-
|
176 |
-
segment_index = config.initial_segment_index
|
177 |
-
|
178 |
-
# For each time segment, run whisper
|
179 |
-
for segment in merged:
|
180 |
-
segment_index += 1
|
181 |
-
segment_start = segment['start']
|
182 |
-
segment_end = segment['end']
|
183 |
-
segment_expand_amount = segment.get('expand_amount', 0)
|
184 |
-
segment_gap = segment.get('gap', False)
|
185 |
-
|
186 |
-
segment_duration = segment_end - segment_start
|
187 |
-
|
188 |
-
if segment_duration < MIN_SEGMENT_DURATION:
|
189 |
-
continue
|
190 |
-
|
191 |
-
# Audio to run on Whisper
|
192 |
-
segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
|
193 |
-
# Previous segments to use as a prompt
|
194 |
-
segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
|
195 |
-
|
196 |
-
# Detected language
|
197 |
-
detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
|
198 |
-
|
199 |
-
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
200 |
-
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
201 |
-
|
202 |
-
scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=max_audio_duration, sub_task_start=segment_start, sub_task_total=segment_duration)
|
203 |
-
segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
|
204 |
-
|
205 |
-
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
adjusted_segment_end = adjusted_segment['end']
|
213 |
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
|
218 |
-
#
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
221 |
|
222 |
-
|
223 |
-
if not segment_gap:
|
224 |
-
languageCounter[segment_result['language']] += 1
|
225 |
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
if detected_language is not None:
|
230 |
-
result['language'] = detected_language
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
return result
|
233 |
|
|
|
|
|
|
|
234 |
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
235 |
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
236 |
# Add segments to the current prompt window (unless it is a speech gap)
|
|
|
153 |
A list of start and end timestamps, in fractional seconds.
|
154 |
"""
|
155 |
|
156 |
+
try:
|
157 |
+
max_audio_duration = self.get_audio_duration(audio, config)
|
158 |
+
timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
|
159 |
|
160 |
+
# Get speech timestamps from full audio file
|
161 |
+
merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
|
162 |
|
163 |
+
# A deque of transcribed segments that is passed to the next segment as a prompt
|
164 |
+
prompt_window = deque()
|
165 |
|
166 |
+
print("Processing timestamps:")
|
167 |
+
pprint(merged)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
result = {
|
170 |
+
'text': "",
|
171 |
+
'segments': [],
|
172 |
+
'language': ""
|
173 |
+
}
|
174 |
+
languageCounter = Counter()
|
175 |
+
detected_language = None
|
176 |
|
177 |
+
segment_index = config.initial_segment_index
|
|
|
178 |
|
179 |
+
# Calculate progress
|
180 |
+
progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
|
181 |
+
progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
|
182 |
|
183 |
+
# For each time segment, run whisper
|
184 |
+
for segment in merged:
|
185 |
+
segment_index += 1
|
186 |
+
segment_start = segment['start']
|
187 |
+
segment_end = segment['end']
|
188 |
+
segment_expand_amount = segment.get('expand_amount', 0)
|
189 |
+
segment_gap = segment.get('gap', False)
|
190 |
|
191 |
+
segment_duration = segment_end - segment_start
|
|
|
|
|
192 |
|
193 |
+
if segment_duration < MIN_SEGMENT_DURATION:
|
194 |
+
continue
|
|
|
|
|
|
|
195 |
|
196 |
+
# Audio to run on Whisper
|
197 |
+
segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
|
198 |
+
# Previous segments to use as a prompt
|
199 |
+
segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
|
200 |
+
|
201 |
+
# Detected language
|
202 |
+
detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
|
203 |
+
|
204 |
+
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
205 |
+
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
206 |
+
|
207 |
+
scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
|
208 |
+
sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
|
209 |
+
segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
|
210 |
+
|
211 |
+
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
212 |
+
|
213 |
+
# Propagate expand amount to the segments
|
214 |
+
if (segment_expand_amount > 0):
|
215 |
+
segment_without_expansion = segment_duration - segment_expand_amount
|
216 |
+
|
217 |
+
for adjusted_segment in adjusted_segments:
|
218 |
+
adjusted_segment_end = adjusted_segment['end']
|
219 |
+
|
220 |
+
# Add expand amount if the segment got expanded
|
221 |
+
if (adjusted_segment_end > segment_without_expansion):
|
222 |
+
adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
|
223 |
+
|
224 |
+
# Append to output
|
225 |
+
result['text'] += segment_result['text']
|
226 |
+
result['segments'].extend(adjusted_segments)
|
227 |
+
|
228 |
+
# Increment detected language
|
229 |
+
if not segment_gap:
|
230 |
+
languageCounter[segment_result['language']] += 1
|
231 |
+
|
232 |
+
# Update prompt window
|
233 |
+
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
|
234 |
+
|
235 |
+
if detected_language is not None:
|
236 |
+
result['language'] = detected_language
|
237 |
+
finally:
|
238 |
+
# Notify progress listener that we are done
|
239 |
+
if progressListener is not None:
|
240 |
+
progressListener.on_finished()
|
241 |
return result
|
242 |
|
243 |
+
def get_audio_duration(self, audio: str, config: TranscriptionConfig):
|
244 |
+
return get_audio_duration(audio)
|
245 |
+
|
246 |
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
247 |
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
248 |
# Add segments to the current prompt window (unless it is a speech gap)
|
src/vadParallel.py
CHANGED
@@ -1,14 +1,33 @@
|
|
1 |
import multiprocessing
|
|
|
2 |
import threading
|
3 |
import time
|
|
|
4 |
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
|
5 |
from src.whisperContainer import WhisperCallback
|
6 |
|
7 |
-
from multiprocessing import Pool
|
8 |
|
9 |
-
from typing import Any, Dict, List
|
10 |
import os
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
class ParallelContext:
|
14 |
def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
|
@@ -86,7 +105,8 @@ class ParallelTranscription(AbstractTranscription):
|
|
86 |
super().__init__(sampling_rate=sampling_rate)
|
87 |
|
88 |
def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
|
89 |
-
cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None
|
|
|
90 |
total_duration = get_audio_duration(audio)
|
91 |
|
92 |
# First, get the timestamps for the original audio
|
@@ -108,6 +128,9 @@ class ParallelTranscription(AbstractTranscription):
|
|
108 |
parameters = []
|
109 |
segment_index = config.initial_segment_index
|
110 |
|
|
|
|
|
|
|
111 |
for i in range(len(gpu_devices)):
|
112 |
# Note that device_segment_list can be empty. But we will still create a process for it,
|
113 |
# as otherwise we run the risk of assigning the same device to multiple processes.
|
@@ -120,7 +143,8 @@ class ParallelTranscription(AbstractTranscription):
|
|
120 |
device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
|
121 |
segment_index += len(device_segment_list)
|
122 |
|
123 |
-
|
|
|
124 |
|
125 |
merged = {
|
126 |
'text': '',
|
@@ -142,7 +166,24 @@ class ParallelTranscription(AbstractTranscription):
|
|
142 |
pool = gpu_parallel_context.get_pool()
|
143 |
|
144 |
# Run the transcription in parallel
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
for result in results:
|
148 |
# Merge the results
|
@@ -231,11 +272,12 @@ class ParallelTranscription(AbstractTranscription):
|
|
231 |
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
|
232 |
# Override timestamps that will be processed
|
233 |
if (config.override_timestamps is not None):
|
234 |
-
print("Using override timestamps of size " + str(len(config.override_timestamps)))
|
235 |
return config.override_timestamps
|
236 |
return super().get_merged_timestamps(timestamps, config, total_duration)
|
237 |
|
238 |
-
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig
|
|
|
239 |
# Override device ID the first time
|
240 |
if (os.environ.get("INITIALIZED", None) is None):
|
241 |
os.environ["INITIALIZED"] = "1"
|
@@ -246,7 +288,7 @@ class ParallelTranscription(AbstractTranscription):
|
|
246 |
print("Using device " + config.device_id)
|
247 |
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
248 |
|
249 |
-
return super().transcribe(audio, whisperCallable, config)
|
250 |
|
251 |
def _split(self, a, n):
|
252 |
"""Split a list into n approximately equal parts."""
|
|
|
1 |
import multiprocessing
|
2 |
+
from queue import Empty
|
3 |
import threading
|
4 |
import time
|
5 |
+
from src.hooks.whisperProgressHook import ProgressListener
|
6 |
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
|
7 |
from src.whisperContainer import WhisperCallback
|
8 |
|
9 |
+
from multiprocessing import Pool, Queue
|
10 |
|
11 |
+
from typing import Any, Dict, List, Union
|
12 |
import os
|
13 |
|
14 |
+
class _ProgressListenerToQueue(ProgressListener):
|
15 |
+
def __init__(self, progress_queue: Queue):
|
16 |
+
self.progress_queue = progress_queue
|
17 |
+
self.progress_total = 0
|
18 |
+
self.prev_progress = 0
|
19 |
+
|
20 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
21 |
+
delta = current - self.prev_progress
|
22 |
+
self.prev_progress = current
|
23 |
+
self.progress_total = total
|
24 |
+
self.progress_queue.put(delta)
|
25 |
+
|
26 |
+
def on_finished(self):
|
27 |
+
if self.progress_total > self.prev_progress:
|
28 |
+
delta = self.progress_total - self.prev_progress
|
29 |
+
self.progress_queue.put(delta)
|
30 |
+
self.prev_progress = self.progress_total
|
31 |
|
32 |
class ParallelContext:
|
33 |
def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
|
|
|
105 |
super().__init__(sampling_rate=sampling_rate)
|
106 |
|
107 |
def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
|
108 |
+
cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
|
109 |
+
progress_listener: ProgressListener = None):
|
110 |
total_duration = get_audio_duration(audio)
|
111 |
|
112 |
# First, get the timestamps for the original audio
|
|
|
128 |
parameters = []
|
129 |
segment_index = config.initial_segment_index
|
130 |
|
131 |
+
processing_manager = multiprocessing.Manager()
|
132 |
+
progress_queue = processing_manager.Queue()
|
133 |
+
|
134 |
for i in range(len(gpu_devices)):
|
135 |
# Note that device_segment_list can be empty. But we will still create a process for it,
|
136 |
# as otherwise we run the risk of assigning the same device to multiple processes.
|
|
|
143 |
device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
|
144 |
segment_index += len(device_segment_list)
|
145 |
|
146 |
+
progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
|
147 |
+
parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
|
148 |
|
149 |
merged = {
|
150 |
'text': '',
|
|
|
166 |
pool = gpu_parallel_context.get_pool()
|
167 |
|
168 |
# Run the transcription in parallel
|
169 |
+
results_async = pool.starmap_async(self.transcribe, parameters)
|
170 |
+
total_progress = 0
|
171 |
+
|
172 |
+
while not results_async.ready():
|
173 |
+
try:
|
174 |
+
delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
|
175 |
+
except Empty:
|
176 |
+
continue
|
177 |
+
|
178 |
+
total_progress += delta
|
179 |
+
if progress_listener is not None:
|
180 |
+
progress_listener.on_progress(total_progress, total_duration)
|
181 |
+
|
182 |
+
results = results_async.get()
|
183 |
+
|
184 |
+
# Call the finished callback
|
185 |
+
if progress_listener is not None:
|
186 |
+
progress_listener.on_finished()
|
187 |
|
188 |
for result in results:
|
189 |
# Merge the results
|
|
|
272 |
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
|
273 |
# Override timestamps that will be processed
|
274 |
if (config.override_timestamps is not None):
|
275 |
+
print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
|
276 |
return config.override_timestamps
|
277 |
return super().get_merged_timestamps(timestamps, config, total_duration)
|
278 |
|
279 |
+
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
|
280 |
+
progressListener: ProgressListener = None):
|
281 |
# Override device ID the first time
|
282 |
if (os.environ.get("INITIALIZED", None) is None):
|
283 |
os.environ["INITIALIZED"] = "1"
|
|
|
288 |
print("Using device " + config.device_id)
|
289 |
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
290 |
|
291 |
+
return super().transcribe(audio, whisperCallable, config, progressListener)
|
292 |
|
293 |
def _split(self, a, n):
|
294 |
"""Split a list into n approximately equal parts."""
|