Spaces:
Runtime error
Runtime error
sanchit-gandhi
commited on
Commit
·
5247fcf
1
Parent(s):
fcd9ad1
Update app.py
Browse files
app.py
CHANGED
@@ -200,13 +200,16 @@ def transcribe(audio_path, task="transcribe", group_by_speaker=True, progress=gr
|
|
200 |
)
|
201 |
|
202 |
# run diarization while we wait for Whisper JAX
|
|
|
203 |
diarization = diarization_pipeline(audio_path)
|
204 |
segments = diarization.for_json()["content"]
|
205 |
|
206 |
# only fetch the transcription result after performing diarization
|
|
|
207 |
transcription, _ = job.result()
|
208 |
|
209 |
# align the ASR transcriptions and diarization timestamps
|
|
|
210 |
transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
|
211 |
|
212 |
return transcription
|
@@ -222,6 +225,7 @@ def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.
|
|
222 |
)
|
223 |
|
224 |
html_embed_str = _return_yt_html_embed(yt_url)
|
|
|
225 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
226 |
filepath = os.path.join(tmpdirname, "video.mp4")
|
227 |
download_yt_audio(yt_url, filepath)
|
@@ -232,15 +236,19 @@ def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.
|
|
232 |
inputs = torch.from_numpy(inputs).float()
|
233 |
inputs = inputs.unsqueeze(0)
|
234 |
|
|
|
|
|
235 |
diarization = diarization_pipeline(
|
236 |
{"waveform": inputs, "sample_rate": SAMPLING_RATE},
|
237 |
)
|
238 |
segments = diarization.for_json()["content"]
|
239 |
|
240 |
# only fetch the transcription result after performing diarization
|
|
|
241 |
_, transcription, _ = job.result()
|
242 |
|
243 |
# align the ASR transcriptions and diarization timestamps
|
|
|
244 |
transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
|
245 |
|
246 |
return html_embed_str, transcription
|
|
|
200 |
)
|
201 |
|
202 |
# run diarization while we wait for Whisper JAX
|
203 |
+
progress(0, desc="Diarizing...")
|
204 |
diarization = diarization_pipeline(audio_path)
|
205 |
segments = diarization.for_json()["content"]
|
206 |
|
207 |
# only fetch the transcription result after performing diarization
|
208 |
+
progress(0.33, desc="Transcribing...")
|
209 |
transcription, _ = job.result()
|
210 |
|
211 |
# align the ASR transcriptions and diarization timestamps
|
212 |
+
progress(0.66, desc="Aligning...")
|
213 |
transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
|
214 |
|
215 |
return transcription
|
|
|
225 |
)
|
226 |
|
227 |
html_embed_str = _return_yt_html_embed(yt_url)
|
228 |
+
progress(0, desc="Downloading YouTube video...")
|
229 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
230 |
filepath = os.path.join(tmpdirname, "video.mp4")
|
231 |
download_yt_audio(yt_url, filepath)
|
|
|
236 |
inputs = torch.from_numpy(inputs).float()
|
237 |
inputs = inputs.unsqueeze(0)
|
238 |
|
239 |
+
# run diarization while we wait for Whisper JAX
|
240 |
+
progress(0.25, desc="Diarizing...")
|
241 |
diarization = diarization_pipeline(
|
242 |
{"waveform": inputs, "sample_rate": SAMPLING_RATE},
|
243 |
)
|
244 |
segments = diarization.for_json()["content"]
|
245 |
|
246 |
# only fetch the transcription result after performing diarization
|
247 |
+
progress(0.50, desc="Transcribing...")
|
248 |
_, transcription, _ = job.result()
|
249 |
|
250 |
# align the ASR transcriptions and diarization timestamps
|
251 |
+
progress(0.75, desc="Aligning...")
|
252 |
transcription = align(transcription, segments, group_by_speaker=group_by_speaker)
|
253 |
|
254 |
return html_embed_str, transcription
|