aadnk commited on
Commit
84fa1f8
·
1 Parent(s): 5bbbb16

Concat first prompt with initial prompt

Browse files
Files changed (2) hide show
  1. app.py +19 -3
  2. src/vad.py +5 -2
app.py CHANGED
@@ -89,9 +89,17 @@ class WhisperTranscriber:
89
 
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
 
 
 
 
 
 
92
  # Callable for processing an audio file
93
- whisperCallable = lambda audio, prompt, detected_language : model.transcribe(audio, \
94
- language=language if language else detected_language, task=task, initial_prompt=prompt, **decodeOptions)
 
 
95
 
96
  # The results
97
  if (vad == 'silero-vad'):
@@ -113,10 +121,18 @@ class WhisperTranscriber:
113
  result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
114
  else:
115
  # Default VAD
116
- result = whisperCallable(audio_path, None, None)
117
 
118
  return result
119
 
 
 
 
 
 
 
 
 
120
  def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
121
  # Use Silero VAD
122
  if (self.vad_model is None):
 
89
 
90
  def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
+
93
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
94
+
95
+ if ('task' in decodeOptions):
96
+ task = decodeOptions.pop('task')
97
+
98
  # Callable for processing an audio file
99
+ whisperCallable = lambda audio, segment_index, prompt, detected_language : model.transcribe(audio, \
100
+ language=language if language else detected_language, task=task, \
101
+ initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
102
+ **decodeOptions)
103
 
104
  # The results
105
  if (vad == 'silero-vad'):
 
121
  result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
122
  else:
123
  # Default VAD
124
+ result = whisperCallable(audio_path, 0, None, None)
125
 
126
  return result
127
 
128
+ def _concat_prompt(self, prompt1, prompt2):
129
+ if (prompt1 is None):
130
+ return prompt2
131
+ elif (prompt2 is None):
132
+ return prompt1
133
+ else:
134
+ return prompt1 + " " + prompt2
135
+
136
  def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
137
  # Use Silero VAD
138
  if (self.vad_model is None):
src/vad.py CHANGED
@@ -100,7 +100,7 @@ class AbstractTranscription(ABC):
100
  audio: str
101
  The audio file.
102
 
103
- whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], str, str], dict[str, Union[dict, Any]]]
104
  The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
105
  the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
106
 
@@ -147,8 +147,11 @@ class AbstractTranscription(ABC):
147
  languageCounter = Counter()
148
  detected_language = None
149
 
 
 
150
  # For each time segment, run whisper
151
  for segment in merged:
 
152
  segment_start = segment['start']
153
  segment_end = segment['end']
154
  segment_expand_amount = segment.get('expand_amount', 0)
@@ -169,7 +172,7 @@ class AbstractTranscription(ABC):
169
 
170
  print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
171
  segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
172
- segment_result = whisperCallable(segment_audio, segment_prompt, detected_language)
173
 
174
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
175
 
 
100
  audio: str
101
  The audio file.
102
 
103
+ whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], int, str, str], dict[str, Union[dict, Any]]]
104
  The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
105
  the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
106
 
 
147
  languageCounter = Counter()
148
  detected_language = None
149
 
150
+ segment_index = -1
151
+
152
  # For each time segment, run whisper
153
  for segment in merged:
154
+ segment_index += 1
155
  segment_start = segment['start']
156
  segment_end = segment['end']
157
  segment_expand_amount = segment.get('expand_amount', 0)
 
172
 
173
  print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
174
  segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
175
+ segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
176
 
177
  adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
178