aadnk committed
Commit 5bbbb16
1 Parent(s): f5884f3

Cleanup code

Files changed (3):
  1. app.py +13 -13
  2. src/segments.py +9 -1
  3. src/vad.py +42 -49
app.py CHANGED
@@ -14,7 +14,7 @@ import gradio as gr
 
 from src.download import ExceededMaximumDuration, download_url
 from src.utils import slugify, write_srt, write_vtt
-from src.vad import NonSpeechStrategy, VadPeriodicTranscription, VadSileroTranscription
+from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -96,38 +96,38 @@ class WhisperTranscriber:
         # The results
         if (vad == 'silero-vad'):
             # Silero VAD where non-speech gaps are transcribed
-            process_gaps = self._create_silero_vad(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = process_gaps.transcribe(audio_path, whisperCallable)
+            process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
+            result = self.vad_model.transcribe(audio_path, whisperCallable, process_gaps)
         elif (vad == 'silero-vad-skip-gaps'):
             # Silero VAD where non-speech gaps are simply ignored
-            skip_gaps = self._create_silero_vad(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = skip_gaps.transcribe(audio_path, whisperCallable)
+            skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
+            result = self.vad_model.transcribe(audio_path, whisperCallable, skip_gaps)
         elif (vad == 'silero-vad-expand-into-gaps'):
             # Use Silero VAD where speech-segments are expanded into non-speech gaps
-            expand_gaps = self._create_silero_vad(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = expand_gaps.transcribe(audio_path, whisperCallable)
+            expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
+            result = self.vad_model.transcribe(audio_path, whisperCallable, expand_gaps)
         elif (vad == 'periodic-vad'):
             # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
             # it may create a break in the middle of a sentence, causing some artifacts.
-            periodic_vad = VadPeriodicTranscription(periodic_duration=vadMaxMergeSize)
-            result = periodic_vad.transcribe(audio_path, whisperCallable)
+            periodic_vad = VadPeriodicTranscription()
+            result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
         else:
             # Default VAD
             result = whisperCallable(audio_path, None, None)
 
         return result
 
-    def _create_silero_vad(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
+    def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
         # Use Silero VAD
         if (self.vad_model is None):
             self.vad_model = VadSileroTranscription()
 
-        result = VadSileroTranscription(non_speech_strategy = non_speech_strategy,
+        config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
             max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
             segment_padding_left=vadPadding, segment_padding_right=vadPadding,
-            max_prompt_window=vadPromptWindow, copy=self.vad_model)
+            max_prompt_window=vadPromptWindow)
 
-        return result
+        return config
 
     def write_result(self, result: dict, source_name: str, output_dir: str):
         if not os.path.exists(output_dir):
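
For context, a minimal sketch (not part of the commit) of how the refactored API is meant to be called: the Silero model is created once and cached, while each request only builds a cheap TranscriptionConfig. The whisperCallable below stands in for whatever callable the app already uses to run Whisper on a segment; its definition is assumed.

    from src.vad import NonSpeechStrategy, TranscriptionConfig, VadSileroTranscription

    # Created once; loads the Silero VAD model via torch.hub.
    vad_model = VadSileroTranscription()

    # Per-request options, mirroring _create_silero_config's defaults above.
    config = TranscriptionConfig(non_speech_strategy=NonSpeechStrategy.SKIP,
                                 max_silent_period=5, max_merge_size=150,
                                 segment_padding_left=1, segment_padding_right=1,
                                 max_prompt_window=1)

    result = vad_model.transcribe("audio.wav", whisperCallable, config)  # whisperCallable assumed defined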
src/segments.py CHANGED
@@ -7,6 +7,13 @@ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5,
 
     if len(timestamps) == 0:
         return result
+    if max_merge_size is None:
+        return timestamps
+
+    if padding_left is None:
+        padding_left = 0
+    if padding_right is None:
+        padding_right = 0
 
     processed_time = 0
     current_segment = None
@@ -17,7 +24,8 @@ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5,
         delta = next_segment['start'] - processed_time
 
         # Note that segments can still be longer than the max merge size, they just won't be merged in that case
-        if current_segment is None or delta > merge_window or next_segment['end'] - current_segment['start'] > max_merge_size:
+        if current_segment is None or (merge_window is not None and delta > merge_window) \
+            or next_segment['end'] - current_segment['start'] > max_merge_size:
             # Finish the current segment
             if current_segment is not None:
                 # Add right padding
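
A quick sketch of what the new None guards change (illustrative only; the parameter names after merge_window are inferred from the function body):

    from src.segments import merge_timestamps

    segments = [{'start': 1.0, 'end': 2.0}, {'start': 3.0, 'end': 4.0}]

    # max_merge_size=None now returns the input unchanged instead of
    # failing on the size comparison further down.
    print(merge_timestamps(segments, merge_window=5, max_merge_size=None))

    # None paddings are treated as 0, and merge_window=None skips the
    # gap check, so these two segments merge on size alone.
    print(merge_timestamps(segments, merge_window=None, max_merge_size=30,
                           padding_left=None, padding_right=None))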
src/vad.py CHANGED
@@ -38,45 +38,43 @@ class NonSpeechStrategy(Enum):
 
 # Defaults for Silero
 SPEECH_TRESHOLD = 0.3
-MAX_SILENT_PERIOD = 10 # seconds
-MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
-
-# Default segment padding
-SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
-SEGMENT_PADDING_RIGHT = 1 # End detected segments late
 
 # Minimum size of segments to process
 MIN_SEGMENT_DURATION = 1
 
-# Always merge segments that are less than this duration apart
-MIN_FORCE_MERGE_GAP = 0.5
-FORCE_MERGE_SEGMENT_MULTIPLIER = 1.5
-
 # The maximum time for texts from old segments to be used in the next segment
 MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
 PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
 
 VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
 
-class AbstractTranscription(ABC):
-    def __init__(self, segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
-            max_merge_size: float = None, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP, max_prompt_window: float = None):
-        self.sampling_rate = 16000
+class TranscriptionConfig(ABC):
+    def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
+            segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
+            max_merge_size: float = None, max_prompt_window: float = None):
+        self.non_speech_strategy = non_speech_strategy
         self.segment_padding_left = segment_padding_left
         self.segment_padding_right = segment_padding_right
         self.max_silent_period = max_silent_period
         self.max_merge_size = max_merge_size
-        self.non_speech_strategy = non_speech_strategy
         self.max_prompt_window = max_prompt_window
 
-        self.min_force_merge_gap = MIN_FORCE_MERGE_GAP
-        self.max_force_merge_size = max_merge_size * FORCE_MERGE_SEGMENT_MULTIPLIER if max_merge_size is not None else None
+class PeriodicTranscriptionConfig(TranscriptionConfig):
+    def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
+            segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
+            max_merge_size: float = None, max_prompt_window: float = None):
+        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
+        self.periodic_duration = periodic_duration
+
+class AbstractTranscription(ABC):
+    def __init__(self, sampling_rate: int = 16000):
+        self.sampling_rate = sampling_rate
 
     def get_audio_segment(self, str, start_time: str = None, duration: str = None):
         return load_audio(str, self.sampling_rate, start_time, duration)
 
     @abstractmethod
-    def get_transcribe_timestamps(self, audio: str):
+    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
         """
         Get the start and end timestamps of the sections that should be transcribed by this VAD method.
 
@@ -84,6 +82,8 @@ class AbstractTranscription(ABC):
         ----------
         audio: str
             The audio file.
+        config: TranscriptionConfig
+            The transcription configuration.
 
         Returns
         -------
@@ -91,7 +91,7 @@
         """
         return
 
-    def transcribe(self, audio: str, whisperCallable):
+    def transcribe(self, audio: str, whisperCallable, config: TranscriptionConfig):
         """
         Transcribe the given audio file.
 
@@ -110,12 +110,12 @@
         """
 
         # get speech timestamps from full audio file
-        seconds_timestamps = self.get_transcribe_timestamps(audio)
+        seconds_timestamps = self.get_transcribe_timestamps(audio, config)
 
         #for seconds_timestamp in seconds_timestamps:
         #    print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
 
-        merged = merge_timestamps(seconds_timestamps, self.max_silent_period, self.max_merge_size, self.segment_padding_left, self.segment_padding_right)
+        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
 
         # A deque of transcribed segments that is passed to the next segment as a prompt
         prompt_window = deque()
@@ -123,18 +123,18 @@
         print("Timestamps:")
         pprint(merged)
 
-        if self.non_speech_strategy != NonSpeechStrategy.SKIP:
+        if config.non_speech_strategy != NonSpeechStrategy.SKIP:
             max_audio_duration = get_audio_duration(audio)
 
             # Expand segments to include the gaps between them
-            if (self.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
+            if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
                 # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
-                merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=self.max_merge_size)
-            elif self.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
+                merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=config.max_merge_size)
+            elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
                 # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
                 merged = self.expand_gaps(merged, total_duration=max_audio_duration)
             else:
-                raise Exception("Unknown non-speech strategy: " + str(self.non_speech_strategy))
+                raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
 
             print("Transcribing non-speech:")
             pprint(merged)
@@ -193,15 +193,15 @@
                 languageCounter[segment_result['language']] += 1
 
             # Update prompt window
-            self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap)
+            self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
 
         if detected_language is not None:
             result['language'] = detected_language
 
         return result
 
-    def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool = False):
-        if (self.max_prompt_window is not None and self.max_prompt_window > 0):
+    def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
+        if (config.max_prompt_window is not None and config.max_prompt_window > 0):
             # Add segments to the current prompt window (unless it is a speech gap)
             if not segment_gap:
                 for segment in adjusted_segments:
@@ -213,7 +213,7 @@
                     # Time expanded in the segments should be discounted from the prompt window
                     first_expand_time = prompt_window[0].get('expand_amount', 0)
 
-                    if (first_end_time - first_expand_time < segment_end - self.max_prompt_window):
+                    if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
                         prompt_window.popleft()
                     else:
                         break
@@ -371,20 +371,14 @@ class AbstractTranscription(ABC):
         return result
 
 class VadSileroTranscription(AbstractTranscription):
-    def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
-            max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
-            max_prompt_window=MAX_PROMPT_WINDOW, copy = None):
-        super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
-            max_silent_period=max_silent_period, max_merge_size=max_merge_size, non_speech_strategy=non_speech_strategy, max_prompt_window=max_prompt_window)
-
-        if copy:
-            self.model = copy.model
-            self.get_speech_timestamps = copy.get_speech_timestamps
-        else:
-            self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
-            (self.get_speech_timestamps, _, _, _, _) = utils
-
-    def get_transcribe_timestamps(self, audio: str):
+    def __init__(self, sampling_rate: int = 16000):
+        super().__init__(sampling_rate=sampling_rate)
+
+        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
+        (self.get_speech_timestamps, _, _, _, _) = utils
+
+
+    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
         audio_duration = get_audio_duration(audio)
         result = []
 
@@ -410,11 +404,10 @@ class VadSileroTranscription(AbstractTranscription):
 
 # A very simple VAD that just marks every N seconds as speech
 class VadPeriodicTranscription(AbstractTranscription):
-    def __init__(self, periodic_duration: float):
-        super().__init__()
-        self.periodic_duration = periodic_duration
+    def __init__(self, sampling_rate: int = 16000):
+        super().__init__(sampling_rate=sampling_rate)
 
-    def get_transcribe_timestamps(self, audio: str):
+    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
         # Get duration in seconds
         audio_duration = get_audio_duration(audio)
         result = []
@@ -423,7 +416,7 @@ class VadPeriodicTranscription(AbstractTranscription):
         start_timestamp = 0
 
         while (start_timestamp < audio_duration):
-            end_timestamp = min(start_timestamp + self.periodic_duration, audio_duration)
+            end_timestamp = min(start_timestamp + config.periodic_duration, audio_duration)
            segment_duration = end_timestamp - start_timestamp
 
            # Minimum duration is 1 second
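
The upshot of splitting AbstractTranscription from TranscriptionConfig: torch.hub.load now runs once per VadSileroTranscription instance, and switching strategies no longer requires the old copy= constructor hack or a fresh model object. A minimal sketch of the new pattern (the file path is a placeholder):

    from src.vad import PeriodicTranscriptionConfig, VadPeriodicTranscription

    vad = VadPeriodicTranscription()  # holds only the sampling rate

    # The same instance can serve different chunk sizes per call; under the
    # old design each chunk size needed its own VadPeriodicTranscription.
    five_minutes = PeriodicTranscriptionConfig(periodic_duration=5 * 60)
    one_minute = PeriodicTranscriptionConfig(periodic_duration=60)

    print(vad.get_transcribe_timestamps("audio.wav", five_minutes))
    print(vad.get_transcribe_timestamps("audio.wav", one_minute))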