Anustup committed
Commit 225975b
1 parent: 9e921ca

Upload 4 files

Files changed (4):
  1. src/download.py +72 -0
  2. src/segments.py +55 -0
  3. src/utils.py +115 -0
  4. src/vad.py +477 -0
src/download.py ADDED
@@ -0,0 +1,72 @@
+ from tempfile import mkdtemp
+ from typing import List
+ from yt_dlp import YoutubeDL
+
+ import yt_dlp
+ from yt_dlp.postprocessor import PostProcessor
+
+ class FilenameCollectorPP(PostProcessor):
+     def __init__(self):
+         super(FilenameCollectorPP, self).__init__(None)
+         self.filenames = []
+
+     def run(self, information):
+         self.filenames.append(information["filepath"])
+         return [], information
+
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
+     try:
+         return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
+     except yt_dlp.utils.DownloadError as e:
+         # In case of an OS error ("file name too long"), retry with a truncated output template
+         if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
+             return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s", destinationDirectory=destinationDirectory, playlistItems=playlistItems)
+         raise
+
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
+     # Create a temporary directory to store the downloaded files
+     if destinationDirectory is None:
+         destinationDirectory = mkdtemp()
+
+     ydl_opts = {
+         "format": "bestaudio/best",
+         'paths': {
+             'home': destinationDirectory
+         }
+     }
+     if playlistItems:
+         ydl_opts['playlist_items'] = playlistItems
+
+     # Add output template if specified
+     if outputTemplate:
+         ydl_opts['outtmpl'] = outputTemplate
+
+     filename_collector = FilenameCollectorPP()
+
+     with YoutubeDL(ydl_opts) as ydl:
+         if maxDuration and maxDuration > 0:
+             info = ydl.extract_info(url, download=False)
+             duration = info['duration']
+
+             if duration >= maxDuration:
+                 raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")
+
+         ydl.add_post_processor(filename_collector)
+         ydl.download([url])
+
+     if len(filename_collector.filenames) <= 0:
+         raise Exception("Cannot download " + url)
+
+     result = []
+
+     for filename in filename_collector.filenames:
+         result.append(filename)
+         print("Downloaded " + filename)
+
+     return result
+
+ class ExceededMaximumDuration(Exception):
+     def __init__(self, videoDuration, maxDuration, message):
+         self.videoDuration = videoDuration
+         self.maxDuration = maxDuration
+         super().__init__(message)
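
A minimal usage sketch (the placeholder URL and the 600-second cap are invented for illustration, not part of the commit):

    from src.download import download_url, ExceededMaximumDuration

    try:
        # Download the best audio of the first playlist item into a temp directory
        files = download_url("https://www.youtube.com/watch?v=VIDEO_ID", maxDuration=600)  # placeholder URL
        print(files)
    except ExceededMaximumDuration as e:
        print(f"Video is {e.videoDuration}s long, limit is {e.maxDuration}s")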
src/segments.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Any, Dict, List
+
+ import copy
+
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
+     result = []
+
+     if len(timestamps) == 0:
+         return result
+     if max_merge_size is None:
+         return timestamps
+
+     if padding_left is None:
+         padding_left = 0
+     if padding_right is None:
+         padding_right = 0
+
+     processed_time = 0
+     current_segment = None
+
+     for i in range(len(timestamps)):
+         next_segment = timestamps[i]
+
+         delta = next_segment['start'] - processed_time
+
+         # Note that segments can still be longer than the max merge size; they just won't be merged in that case
+         if current_segment is None or (merge_window is not None and delta > merge_window) \
+             or next_segment['end'] - current_segment['start'] > max_merge_size:
+             # Finish the current segment
+             if current_segment is not None:
+                 # Add right padding
+                 finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
+                 current_segment['end'] += finish_padding
+                 delta -= finish_padding
+
+                 result.append(current_segment)
+
+             # Start a new segment
+             current_segment = copy.deepcopy(next_segment)
+
+             # Pad the segment
+             current_segment['start'] = current_segment['start'] - min(padding_left, delta)
+             processed_time = current_segment['end']
+
+         else:
+             # Merge the segment
+             current_segment['end'] = next_segment['end']
+             processed_time = current_segment['end']
+
+     # Add the last segment
+     if current_segment is not None:
+         current_segment['end'] += padding_right
+         result.append(current_segment)
+
+     return result
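
A quick sketch of how merge_timestamps behaves, on invented VAD output (nearby segments are merged and padded; distant ones start a new segment):

    from src.segments import merge_timestamps

    vad_segments = [
        {'start': 1.0, 'end': 3.0},
        {'start': 4.0, 'end': 6.0},    # within merge_window of the previous segment
        {'start': 20.0, 'end': 25.0},  # too far away to merge
    ]
    merged = merge_timestamps(vad_segments, merge_window=5, max_merge_size=30,
                              padding_left=1, padding_right=1)
    # -> [{'start': 0.0, 'end': 7.0}, {'start': 19.0, 'end': 26.0}]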
src/utils.py ADDED
@@ -0,0 +1,115 @@
+ import textwrap
+ import unicodedata
+ import re
+
+ import zlib
+ from typing import Iterator, TextIO
+
+
+ def exact_div(x, y):
+     assert x % y == 0
+     return x // y
+
+
+ def str2bool(string):
+     str2val = {"True": True, "False": False}
+     if string in str2val:
+         return str2val[string]
+     else:
+         raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+ def optional_int(string):
+     return None if string == "None" else int(string)
+
+
+ def optional_float(string):
+     return None if string == "None" else float(string)
+
+
+ def compression_ratio(text) -> float:
+     return len(text) / len(zlib.compress(text.encode("utf-8")))
+
+
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
+     assert seconds >= 0, "non-negative timestamp expected"
+     milliseconds = round(seconds * 1000.0)
+
+     hours = milliseconds // 3_600_000
+     milliseconds -= hours * 3_600_000
+
+     minutes = milliseconds // 60_000
+     milliseconds -= minutes * 60_000
+
+     seconds = milliseconds // 1_000
+     milliseconds -= seconds * 1_000
+
+     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+     return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
+
+
+ def write_txt(transcript: Iterator[dict], file: TextIO):
+     for segment in transcript:
+         print(segment['text'].strip(), file=file, flush=True)
+
+
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+     print("WEBVTT\n", file=file)
+     for segment in transcript:
+         text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
+
+         print(
+             f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+             f"{text}\n",
+             file=file,
+             flush=True,
+         )
+
+
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+     """
+     Write a transcript to a file in SRT format.
+     Example usage:
+         from pathlib import Path
+         from whisper.utils import write_srt
+         result = transcribe(model, audio_path, temperature=temperature, **args)
+         # save SRT
+         audio_basename = Path(audio_path).stem
+         with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+             write_srt(result["segments"], file=srt)
+     """
+     for i, segment in enumerate(transcript, start=1):
+         text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
+
+         # write srt lines
+         print(
+             f"{i}\n"
+             f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
+             f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
+             f"{text}\n",
+             file=file,
+             flush=True,
+         )
+
+ def process_text(text: str, maxLineWidth=None):
+     if maxLineWidth is None or maxLineWidth < 0:
+         return text
+
+     lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
+     return '\n'.join(lines)
+
+ def slugify(value, allow_unicode=False):
+     """
+     Taken from https://github.com/django/django/blob/master/django/utils/text.py
+     Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+     dashes to single dashes. Remove characters that aren't alphanumerics,
+     underscores, or hyphens. Convert to lowercase. Also strip leading and
+     trailing whitespace, dashes, and underscores.
+     """
+     value = str(value)
+     if allow_unicode:
+         value = unicodedata.normalize('NFKC', value)
+     else:
+         value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
+     value = re.sub(r'[^\w\s-]', '', value.lower())
+     return re.sub(r'[-\s]+', '-', value).strip('-_')
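
A few illustrative calls (the input values are invented):

    from src.utils import format_timestamp, slugify

    format_timestamp(3661.5)                         # '01:01:01.500'
    format_timestamp(61.5, always_include_hours=True,
                     fractionalSeperator=',')        # '00:01:01,500' (SRT-style)
    slugify("Mön video: part 1!")                    # 'mon-video-part-1'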
src/vad.py ADDED
@@ -0,0 +1,477 @@
+ from abc import ABC, abstractmethod
+ from collections import Counter, deque
+
+ from typing import Any, Deque, Iterator, List, Dict
+
+ from pprint import pprint
+
+ from src.segments import merge_timestamps
+
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
+ try:
+     import tensorflow as tf
+ except ModuleNotFoundError:
+     # TensorFlow is optional; ignore if it is not installed
+     pass
+
+ import torch
+
+ import ffmpeg
+ import numpy as np
+
+ from src.utils import format_timestamp
+ from enum import Enum
+
+ class NonSpeechStrategy(Enum):
+     """
+     Ignore non-speech segments.
+     """
+     SKIP = 1
+     """
+     Just treat non-speech segments as speech.
+     """
+     CREATE_SEGMENT = 2
+     """
+     Expand speech segments into subsequent non-speech segments.
+     """
+     EXPAND_SEGMENT = 3
+
+ # Defaults for Silero
+ SPEECH_THRESHOLD = 0.3
+
+ # Minimum size of segments to process
+ MIN_SEGMENT_DURATION = 1
+
+ # The maximum time for texts from old segments to be used in the next segment
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
+
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
+
+ class TranscriptionConfig(ABC):
+     def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
+                  segment_padding_left: float = None, segment_padding_right: float = None, max_silent_period: float = None,
+                  max_merge_size: float = None, max_prompt_window: float = None):
+         self.non_speech_strategy = non_speech_strategy
+         self.segment_padding_left = segment_padding_left
+         self.segment_padding_right = segment_padding_right
+         self.max_silent_period = max_silent_period
+         self.max_merge_size = max_merge_size
+         self.max_prompt_window = max_prompt_window
+
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
+     def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
+                  segment_padding_left: float = None, segment_padding_right: float = None, max_silent_period: float = None,
+                  max_merge_size: float = None, max_prompt_window: float = None):
+         super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
+         self.periodic_duration = periodic_duration
+
+ class AbstractTranscription(ABC):
+     def __init__(self, sampling_rate: int = 16000):
+         self.sampling_rate = sampling_rate
+
+     def get_audio_segment(self, path: str, start_time: str = None, duration: str = None):
+         return load_audio(path, self.sampling_rate, start_time, duration)
+
+     @abstractmethod
+     def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
+         """
+         Get the start and end timestamps of the sections that should be transcribed by this VAD method.
+
+         Parameters
+         ----------
+         audio: str
+             The audio file.
+         config: TranscriptionConfig
+             The transcription configuration.
+
+         Returns
+         -------
+         A list of start and end timestamps, in fractional seconds.
+         """
+         return
+
+     def transcribe(self, audio: str, whisperCallable, config: TranscriptionConfig):
+         """
+         Transcribe the given audio file.
+
+         Parameters
+         ----------
+         audio: str
+             The audio file.
+
+         whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], int, str, str], dict[str, Union[dict, Any]]]
+             The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer, the second
+             is the segment index, the third is an optional text prompt, and the last is the currently detected language. The return value is the result of the Whisper call.
+
+         Returns
+         -------
+         The combined transcription result: a dict with the full 'text', the timestamp-adjusted 'segments', and the detected 'language'.
+         """
+
+         # get speech timestamps from full audio file
+         seconds_timestamps = self.get_transcribe_timestamps(audio, config)
+
+         #for seconds_timestamp in seconds_timestamps:
+         #    print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
+
+         merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
+
+         # A deque of transcribed segments that is passed to the next segment as a prompt
+         prompt_window = deque()
+
+         print("Timestamps:")
+         pprint(merged)
+
+         if config.non_speech_strategy != NonSpeechStrategy.SKIP:
+             max_audio_duration = get_audio_duration(audio)
+
+             # Expand segments to include the gaps between them
+             if config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT:
+                 # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
+                 merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=config.max_merge_size)
+             elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
+                 # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
+                 merged = self.expand_gaps(merged, total_duration=max_audio_duration)
+             else:
+                 raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
+
+             print("Transcribing non-speech:")
+             pprint(merged)
+
+         result = {
+             'text': "",
+             'segments': [],
+             'language': ""
+         }
+         languageCounter = Counter()
+         detected_language = None
+
+         segment_index = -1
+
+         # For each time segment, run whisper
+         for segment in merged:
+             segment_index += 1
+             segment_start = segment['start']
+             segment_end = segment['end']
+             segment_expand_amount = segment.get('expand_amount', 0)
+             segment_gap = segment.get('gap', False)
+
+             segment_duration = segment_end - segment_start
+
+             if segment_duration < MIN_SEGMENT_DURATION:
+                 continue
+
+             # Audio to run on Whisper
+             segment_audio = self.get_audio_segment(audio, start_time=str(segment_start), duration=str(segment_duration))
+             # Previous segments to use as a prompt
+             segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
+
+             # Detected language
+             detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
+
+             print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
+                   segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
+             segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
+
+             adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
+
+             # Propagate expand amount to the segments
+             if segment_expand_amount > 0:
+                 segment_without_expansion = segment_duration - segment_expand_amount
+
+                 for adjusted_segment in adjusted_segments:
+                     adjusted_segment_end = adjusted_segment['end']
+
+                     # Add expand amount if the segment got expanded
+                     if adjusted_segment_end > segment_without_expansion:
+                         adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
+
+             # Append to output
+             result['text'] += segment_result['text']
+             result['segments'].extend(adjusted_segments)
+
+             # Increment detected language
+             if not segment_gap:
+                 languageCounter[segment_result['language']] += 1
+
+             # Update prompt window
+             self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
+
+         if detected_language is not None:
+             result['language'] = detected_language
+
+         return result
+
+     def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
+         if config.max_prompt_window is not None and config.max_prompt_window > 0:
+             # Add segments to the current prompt window (unless it is a speech gap)
+             if not segment_gap:
+                 for segment in adjusted_segments:
+                     if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
+                         prompt_window.append(segment)
+
+             while len(prompt_window) > 0:
+                 first_end_time = prompt_window[0].get('end', 0)
+                 # Time expanded in the segments should be discounted from the prompt window
+                 first_expand_time = prompt_window[0].get('expand_amount', 0)
+
+                 if first_end_time - first_expand_time < segment_end - config.max_prompt_window:
+                     prompt_window.popleft()
+                 else:
+                     break
+
+     def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
+         result = []
+         last_end_time = 0
+
+         for segment in segments:
+             segment_start = float(segment['start'])
+             segment_end = float(segment['end'])
+
+             if last_end_time != segment_start:
+                 delta = segment_start - last_end_time
+
+                 if min_gap_length is None or delta >= min_gap_length:
+                     result.append({ 'start': last_end_time, 'end': segment_start, 'gap': True })
+
+             last_end_time = segment_end
+             result.append(segment)
+
+         # Also include total duration if specified
+         if total_duration is not None and last_end_time < total_duration:
+             delta = total_duration - last_end_time
+
+             if min_gap_length is None or delta >= min_gap_length:
+                 result.append({ 'start': last_end_time, 'end': total_duration, 'gap': True })
+
+         return result
+
+     # Expand the end time of each segment to the start of the next segment
+     def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
+         result = []
+
+         if len(segments) == 0:
+             return result
+
+         # Add gap at the beginning if needed
+         if segments[0]['start'] > 0:
+             result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True })
+
+         for i in range(len(segments) - 1):
+             current_segment = segments[i]
+             next_segment = segments[i + 1]
+
+             delta = next_segment['start'] - current_segment['end']
+
+             # Expand if the gap actually exists
+             if delta >= 0:
+                 current_segment = current_segment.copy()
+                 current_segment['expand_amount'] = delta
+                 current_segment['end'] = next_segment['start']
+
+             result.append(current_segment)
+
+         # Add last segment
+         last_segment = segments[-1]
+         result.append(last_segment)
+
+         # Also include total duration if specified
+         if total_duration is not None:
+             last_segment = result[-1]
+
+             if last_segment['end'] < total_duration:
+                 last_segment = last_segment.copy()
+                 last_segment['end'] = total_duration
+                 result[-1] = last_segment
+
+         return result
+
+     def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
+         result = []
+
+         if len(segments) == 0:
+             return result
+
+         # Add gap at the beginning if needed
+         if segments[0]['start'] > 0:
+             result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True })
+
+         for i in range(len(segments) - 1):
+             expanded = False
+             current_segment = segments[i]
+             next_segment = segments[i + 1]
+
+             delta = next_segment['start'] - current_segment['end']
+
+             if max_expand_size is not None and delta <= max_expand_size:
+                 # Just expand the current segment
+                 current_segment = current_segment.copy()
+                 current_segment['expand_amount'] = delta
+                 current_segment['end'] = next_segment['start']
+                 expanded = True
+
+             result.append(current_segment)
+
+             # Add a gap to the next segment if needed
+             if delta >= 0 and not expanded:
+                 result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True })
+
+         # Add last segment
+         last_segment = segments[-1]
+         result.append(last_segment)
+
+         # Also include total duration if specified
+         if total_duration is not None:
+             last_segment = result[-1]
+
+             delta = total_duration - last_segment['end']
+
+             if delta > 0:
+                 if max_expand_size is not None and delta <= max_expand_size:
+                     # Expand the last segment
+                     last_segment = last_segment.copy()
+                     last_segment['expand_amount'] = delta
+                     last_segment['end'] = total_duration
+                     result[-1] = last_segment
+                 else:
+                     result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True })
+
+         return result
+
+     def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
+         result = []
+
+         for segment in segments:
+             segment_start = float(segment['start'])
+             segment_end = float(segment['end'])
+
+             # Filter out segments that start past the end of the source audio
+             if max_source_time is not None:
+                 if segment_start > max_source_time:
+                     continue
+                 segment_end = min(max_source_time, segment_end)
+
+             new_segment = segment.copy()
+
+             # Add to start and end
+             new_segment['start'] = segment_start + adjust_seconds
+             new_segment['end'] = segment_end + adjust_seconds
+             result.append(new_segment)
+         return result
+
+     def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
+         result = []
+
+         for entry in timestamps:
+             start = entry['start']
+             end = entry['end']
+
+             result.append({
+                 'start': start * factor,
+                 'end': end * factor
+             })
+         return result
+
+ class VadSileroTranscription(AbstractTranscription):
+     def __init__(self, sampling_rate: int = 16000):
+         super().__init__(sampling_rate=sampling_rate)
+
+         self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+         (self.get_speech_timestamps, _, _, _, _) = utils
+
+
+     def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
+         audio_duration = get_audio_duration(audio)
+         result = []
+
+         # Divide processing of the audio into chunks
+         chunk_start = 0.0
+
+         while chunk_start < audio_duration:
+             chunk_duration = min(audio_duration - chunk_start, VAD_MAX_PROCESSING_CHUNK)
+
+             print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
+             wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
+
+             sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_THRESHOLD)
+             seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
+             adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
+
+             #pprint(adjusted)
+
+             result.extend(adjusted)
+             chunk_start += chunk_duration
+
+         return result
+
+ # A very simple VAD that just marks every N seconds as speech
+ class VadPeriodicTranscription(AbstractTranscription):
+     def __init__(self, sampling_rate: int = 16000):
+         super().__init__(sampling_rate=sampling_rate)
+
+     def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
+         # Get duration in seconds
+         audio_duration = get_audio_duration(audio)
+         result = []
+
+         # Generate a timestamp every N seconds
+         start_timestamp = 0
+
+         while start_timestamp < audio_duration:
+             end_timestamp = min(start_timestamp + config.periodic_duration, audio_duration)
+             segment_duration = end_timestamp - start_timestamp
+
+             # Minimum duration is 1 second
+             if segment_duration >= 1:
+                 result.append({ 'start': start_timestamp, 'end': end_timestamp })
+
+             start_timestamp = end_timestamp
+
+         return result
+
+ def get_audio_duration(file: str):
+     return float(ffmpeg.probe(file)["format"]["duration"])
+
+ def load_audio(file: str, sample_rate: int = 16000,
+                start_time: str = None, duration: str = None):
+     """
+     Open an audio file and read as mono waveform, resampling as necessary.
+
+     Parameters
+     ----------
+     file: str
+         The audio file to open.
+
+     sample_rate: int
+         The sample rate to resample the audio to, if necessary.
+
+     start_time: str
+         The start time, using the standard FFmpeg time duration syntax, or None to disable.
+
+     duration: str
+         The duration, using the standard FFmpeg time duration syntax, or None to disable.
+
+     Returns
+     -------
+     A NumPy array containing the audio waveform, in float32 dtype.
+     """
+     try:
+         inputArgs = {'threads': 0}
+
+         if start_time is not None:
+             inputArgs['ss'] = start_time
+         if duration is not None:
+             inputArgs['t'] = duration
+
+         # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+         # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+         out, _ = (
+             ffmpeg.input(file, **inputArgs)
+             .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
+             .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
+         )
+     except ffmpeg.Error as e:
+         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
+
+     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
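
A sketch of how these pieces are meant to be wired together (the model loading, the transcribe arguments, and the "audio.wav" path are assumptions based on the openai-whisper API, not part of this commit):

    import whisper
    from src.vad import VadSileroTranscription, TranscriptionConfig

    model = whisper.load_model("base")

    def whisper_callable(audio, segment_index, prompt, detected_language):
        # Run Whisper on a single VAD-detected chunk (a 16 kHz float32 array)
        return model.transcribe(audio, initial_prompt=prompt, language=detected_language)

    vad = VadSileroTranscription()
    config = TranscriptionConfig(max_silent_period=30, max_merge_size=30,
                                 segment_padding_left=1, segment_padding_right=1)
    result = vad.transcribe("audio.wav", whisper_callable, config)  # placeholder path
    print(result['text'])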