sanchit-gandhi commited on
Commit
5247fcf
·
1 Parent(s): fcd9ad1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -200,13 +200,16 @@ def transcribe(audio_path, task="transcribe", group_by_speaker=True, progress=gr
200
  )
201
 
202
  # run diarization while we wait for Whisper JAX
 
203
  diarization = diarization_pipeline(audio_path)
204
  segments = diarization.for_json()["content"]
205
 
206
  # only fetch the transcription result after performing diarization
 
207
  transcription, _ = job.result()
208
 
209
  # align the ASR transcriptions and diarization timestamps
 
210
  transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
211
 
212
  return transcription
@@ -222,6 +225,7 @@ def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.
222
  )
223
 
224
  html_embed_str = _return_yt_html_embed(yt_url)
 
225
  with tempfile.TemporaryDirectory() as tmpdirname:
226
  filepath = os.path.join(tmpdirname, "video.mp4")
227
  download_yt_audio(yt_url, filepath)
@@ -232,15 +236,19 @@ def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.
232
  inputs = torch.from_numpy(inputs).float()
233
  inputs = inputs.unsqueeze(0)
234
 
 
 
235
  diarization = diarization_pipeline(
236
  {"waveform": inputs, "sample_rate": SAMPLING_RATE},
237
  )
238
  segments = diarization.for_json()["content"]
239
 
240
  # only fetch the transcription result after performing diarization
 
241
  _, transcription, _ = job.result()
242
 
243
  # align the ASR transcriptions and diarization timestamps
 
244
  transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
245
 
246
  return html_embed_str, transcription
 
200
  )
201
 
202
  # run diarization while we wait for Whisper JAX
203
+ progress(0, desc="Diarizing...")
204
  diarization = diarization_pipeline(audio_path)
205
  segments = diarization.for_json()["content"]
206
 
207
  # only fetch the transcription result after performing diarization
208
+ progress(0.33, desc="Transcribing...")
209
  transcription, _ = job.result()
210
 
211
  # align the ASR transcriptions and diarization timestamps
212
+ progress(0.66, desc="Aligning...")
213
  transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
214
 
215
  return transcription
 
225
  )
226
 
227
  html_embed_str = _return_yt_html_embed(yt_url)
228
+ progress(0, desc="Downloading YouTube video...")
229
  with tempfile.TemporaryDirectory() as tmpdirname:
230
  filepath = os.path.join(tmpdirname, "video.mp4")
231
  download_yt_audio(yt_url, filepath)
 
236
  inputs = torch.from_numpy(inputs).float()
237
  inputs = inputs.unsqueeze(0)
238
 
239
+ # run diarization while we wait for Whisper JAX
240
+ progress(0.25, desc="Diarizing...")
241
  diarization = diarization_pipeline(
242
  {"waveform": inputs, "sample_rate": SAMPLING_RATE},
243
  )
244
  segments = diarization.for_json()["content"]
245
 
246
  # only fetch the transcription result after performing diarization
247
+ progress(0.50, desc="Transcribing...")
248
  _, transcription, _ = job.result()
249
 
250
  # align the ASR transcriptions and diarization timestamps
251
+ progress(0.75, desc="Aligning...")
252
  transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
253
 
254
  return html_embed_str, transcription