aadnk commited on
Commit
f1fe464
·
1 Parent(s): 31ba778

Ensure VAD supports detect language

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. src/vad.py +12 -7
app.py CHANGED
@@ -90,7 +90,8 @@ class WhisperTranscriber:
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
  # Callable for processing an audio file
93
- whisperCallable = lambda audio, prompt : model.transcribe(audio, language=language, task=task, initial_prompt=prompt, **decodeOptions)
 
94
 
95
  # The results
96
  if (vad == 'silero-vad'):
@@ -112,7 +113,7 @@ class WhisperTranscriber:
112
  result = periodic_vad.transcribe(audio_path, whisperCallable)
113
  else:
114
  # Default VAD
115
- result = whisperCallable(audio_path, None)
116
 
117
  return result
118
 
 
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
  # Callable for processing an audio file
93
+ whisperCallable = lambda audio, prompt, detected_language : model.transcribe(audio, \
94
+ language=language if language else detected_language, task=task, initial_prompt=prompt, **decodeOptions)
95
 
96
  # The results
97
  if (vad == 'silero-vad'):
 
113
  result = periodic_vad.transcribe(audio_path, whisperCallable)
114
  else:
115
  # Default VAD
116
+ result = whisperCallable(audio_path, None, None)
117
 
118
  return result
119
 
src/vad.py CHANGED
@@ -100,9 +100,9 @@ class AbstractTranscription(ABC):
100
  audio: str
101
  The audio file.
102
 
103
- whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], str], dict[str, Union[dict, Any]]]
104
  The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
105
- and the second parameter is an optional text prompt. The return value is the result of the Whisper call.
106
 
107
  Returns
108
  -------
@@ -145,6 +145,7 @@ class AbstractTranscription(ABC):
145
  'language': ""
146
  }
147
  languageCounter = Counter()
 
148
 
149
  # For each time segment, run whisper
150
  for segment in merged:
@@ -163,9 +164,12 @@ class AbstractTranscription(ABC):
163
  # Previous segments to use as a prompt
164
  segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
165
 
 
 
 
166
  print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
167
- segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt)
168
- segment_result = whisperCallable(segment_audio, segment_prompt)
169
 
170
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
171
 
@@ -185,13 +189,14 @@ class AbstractTranscription(ABC):
185
  result['segments'].extend(adjusted_segments)
186
 
187
  # Increment detected language
188
- languageCounter[segment_result['language']] += 1
 
189
 
190
  # Update prompt window
191
  self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap)
192
 
193
- if len(languageCounter) > 0:
194
- result['language'] = languageCounter.most_common(1)[0][0]
195
 
196
  return result
197
 
 
100
  audio: str
101
  The audio file.
102
 
103
+ whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], str, str], dict[str, Union[dict, Any]]]
104
  The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
105
+ the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
106
 
107
  Returns
108
  -------
 
145
  'language': ""
146
  }
147
  languageCounter = Counter()
148
+ detected_language = None
149
 
150
  # For each time segment, run whisper
151
  for segment in merged:
 
164
  # Previous segments to use as a prompt
165
  segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
166
 
167
+ # Detected language
168
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
169
+
170
  print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
171
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
172
+ segment_result = whisperCallable(segment_audio, segment_prompt, detected_language)
173
 
174
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
175
 
 
189
  result['segments'].extend(adjusted_segments)
190
 
191
  # Increment detected language
192
+ if not segment_gap:
193
+ languageCounter[segment_result['language']] += 1
194
 
195
  # Update prompt window
196
  self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap)
197
 
198
+ if detected_language is not None:
199
+ result['language'] = detected_language
200
 
201
  return result
202