saq1b committed on
Commit
aa19ef5
·
verified ·
1 Parent(s): 38a4215

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -57
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import gradio as gr
2
  from pydub import AudioSegment
3
  from google import genai
4
  from google.genai import types
@@ -284,7 +284,7 @@ Follow this example structure:
284
 
285
  try:
286
  if progress:
287
- progress(0.3, "Generating podcast script...")
288
 
289
  # Add timeout to the API call
290
  response = await asyncio.wait_for(
@@ -306,19 +306,19 @@ Follow this example structure:
306
  timeout=60 # 60 seconds timeout
307
  )
308
  except asyncio.TimeoutError:
309
- raise gr.Error("The script generation request timed out. Please try again later.")
310
  except Exception as e:
311
  if "API key not valid" in str(e):
312
- raise gr.Error("Invalid API key. Please provide a valid Gemini API key.")
313
  elif "rate limit" in str(e).lower():
314
- raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own Gemini API key.")
315
  else:
316
- raise gr.Error(f"Failed to generate podcast script: {e}")
317
 
318
  print(f"Generated podcast script:\n{response.text}")
319
 
320
  if progress:
321
- progress(0.4, "Script generated successfully!")
322
 
323
  return json.loads(response.text)
324
 
@@ -327,7 +327,7 @@ Follow this example structure:
327
  # Check file size before reading
328
  file_size = os.path.getsize(file_obj.name)
329
  if file_size > MAX_FILE_SIZE_BYTES:
330
- raise gr.Error(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
331
 
332
  async with aiofiles.open(file_obj.name, 'rb') as f:
333
  return await f.read()
@@ -356,7 +356,7 @@ Follow this example structure:
356
  except asyncio.TimeoutError:
357
  if os.path.exists(temp_filename):
358
  os.remove(temp_filename)
359
- raise gr.Error("Text-to-speech generation timed out. Please try with a shorter text.")
360
  except Exception as e:
361
  if os.path.exists(temp_filename):
362
  os.remove(temp_filename)
@@ -364,7 +364,7 @@ Follow this example structure:
364
 
365
  async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
366
  if progress:
367
- progress(0.9, "Combining audio files...")
368
 
369
  combined_audio = AudioSegment.empty()
370
  for audio_file in audio_files:
@@ -375,14 +375,14 @@ Follow this example structure:
375
  combined_audio.export(output_filename, format="wav")
376
 
377
  if progress:
378
- progress(1.0, "Podcast generated successfully!")
379
 
380
  return output_filename
381
 
382
  async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
383
  try:
384
  if progress:
385
- progress(0.1, "Starting podcast generation...")
386
 
387
  # Set overall timeout for the entire process
388
  return await asyncio.wait_for(
@@ -390,18 +390,18 @@ Follow this example structure:
390
  timeout=600 # 10 minutes total timeout
391
  )
392
  except asyncio.TimeoutError:
393
- raise gr.Error("The podcast generation process timed out. Please try with shorter text or try again later.")
394
  except Exception as e:
395
- raise gr.Error(f"Error generating podcast: {str(e)}")
396
 
397
  async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
398
  if progress:
399
- progress(0.2, "Generating podcast script...")
400
 
401
  podcast_json = await self.generate_script(input_text, language, api_key, file_obj, progress)
402
 
403
  if progress:
404
- progress(0.5, "Converting text to speech...")
405
 
406
  # Process TTS in batches to prevent overwhelming the system
407
  audio_files = []
@@ -410,7 +410,7 @@ Follow this example structure:
410
  for i, item in enumerate(podcast_json['podcast']):
411
  if progress:
412
  current_progress = 0.5 + (0.4 * (i / total_lines))
413
- progress(current_progress, f"Processing speech {i+1}/{total_lines}...")
414
 
415
  try:
416
  audio_file = await self.tts_generate(item['line'], item['speaker'], speaker1, speaker2)
@@ -420,12 +420,12 @@ Follow this example structure:
420
  for file in audio_files:
421
  if os.path.exists(file):
422
  os.remove(file)
423
- raise gr.Error(f"Error generating speech for line {i+1}: {str(e)}")
424
 
425
  combined_audio = await self.combine_audio_files(audio_files, progress)
426
  return combined_audio
427
-
428
- async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "", progress=gr.Progress()) -> str:
429
  start_time = time.time()
430
 
431
  voice_names = {
@@ -443,12 +443,13 @@ async def process_input(input_text: str, input_file, language: str, speaker1: st
443
  speaker2 = voice_names[speaker2]
444
 
445
  try:
446
- progress(0.05, "Processing input...")
 
447
 
448
  if not api_key:
449
  api_key = os.getenv("GENAI_API_KEY")
450
  if not api_key:
451
- raise gr.Error("No API key provided. Please provide a Gemini API key.")
452
 
453
  podcast_generator = PodcastGenerator()
454
  podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key, input_file, progress)
@@ -461,18 +462,25 @@ async def process_input(input_text: str, input_file, language: str, speaker1: st
461
  # Ensure we show a user-friendly error
462
  error_msg = str(e)
463
  if "rate limit" in error_msg.lower():
464
- raise gr.Error("Rate limit exceeded. Please try again later or use your own API key.")
465
  elif "timeout" in error_msg.lower():
466
- raise gr.Error("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.")
467
  else:
468
- raise gr.Error(f"Error: {error_msg}")
469
 
470
- iface = gr.Interface(
471
- fn=process_input,
472
- inputs=[
473
- gr.Textbox(label="Input Text"),
474
- gr.File(label="Or Upload a PDF or TXT file"),
475
- gr.Dropdown(label="Language", choices=[
 
 
 
 
 
 
 
476
  "Auto Detect",
477
  "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian", "Azerbaijani",
478
  "Bahasa Indonesian", "Bangla", "Basque", "Bengali", "Bosnian", "Bulgarian",
@@ -487,20 +495,10 @@ iface = gr.Interface(
487
  "Slovak", "Slovene", "Somali", "Spanish", "Sundanese", "Swahili",
488
  "Swedish", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian", "Urdu",
489
  "Uzbek", "Vietnamese", "Welsh", "Zulu"
490
- ],
491
- value="Auto Detect"),
492
- gr.Dropdown(label="Speaker 1 Voice", choices=[
493
- "Andrew - English (United States)",
494
- "Ava - English (United States)",
495
- "Brian - English (United States)",
496
- "Emma - English (United States)",
497
- "Florian - German (Germany)",
498
- "Seraphina - German (Germany)",
499
- "Remy - French (France)",
500
- "Vivienne - French (France)"
501
- ],
502
- value="Andrew - English (United States)"),
503
- gr.Dropdown(label="Speaker 2 Voice", choices=[
504
  "Andrew - English (United States)",
505
  "Ava - English (United States)",
506
  "Brian - English (United States)",
@@ -509,17 +507,95 @@ iface = gr.Interface(
509
  "Seraphina - German (Germany)",
510
  "Remy - French (France)",
511
  "Vivienne - French (France)"
512
- ],
513
- value="Ava - English (United States)"),
514
- gr.Textbox(label="Your Gemini API Key (Optional) - In case you are getting rate limited"),
515
- ],
516
- outputs=[
517
- gr.Audio(label="Generated Podcast Audio")
518
- ],
519
- title="PodcastGen 🎙️",
520
- description="Generate a 2-speaker podcast from text input or documents!",
521
- allow_flagging="never",
522
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  if __name__ == "__main__":
525
- iface.launch()
 
1
+ import streamlit as st
2
  from pydub import AudioSegment
3
  from google import genai
4
  from google.genai import types
 
284
 
285
  try:
286
  if progress:
287
+ progress.progress(0.3, "Generating podcast script...")
288
 
289
  # Add timeout to the API call
290
  response = await asyncio.wait_for(
 
306
  timeout=60 # 60 seconds timeout
307
  )
308
  except asyncio.TimeoutError:
309
+ raise Exception("The script generation request timed out. Please try again later.")
310
  except Exception as e:
311
  if "API key not valid" in str(e):
312
+ raise Exception("Invalid API key. Please provide a valid Gemini API key.")
313
  elif "rate limit" in str(e).lower():
314
+ raise Exception("Rate limit exceeded for the API key. Please try again later or provide your own Gemini API key.")
315
  else:
316
+ raise Exception(f"Failed to generate podcast script: {e}")
317
 
318
  print(f"Generated podcast script:\n{response.text}")
319
 
320
  if progress:
321
+ progress.progress(0.4, "Script generated successfully!")
322
 
323
  return json.loads(response.text)
324
 
 
327
  # Check file size before reading
328
  file_size = os.path.getsize(file_obj.name)
329
  if file_size > MAX_FILE_SIZE_BYTES:
330
+ raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
331
 
332
  async with aiofiles.open(file_obj.name, 'rb') as f:
333
  return await f.read()
 
356
  except asyncio.TimeoutError:
357
  if os.path.exists(temp_filename):
358
  os.remove(temp_filename)
359
+ raise Exception("Text-to-speech generation timed out. Please try with a shorter text.")
360
  except Exception as e:
361
  if os.path.exists(temp_filename):
362
  os.remove(temp_filename)
 
364
 
365
  async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
366
  if progress:
367
+ progress.progress(0.9, "Combining audio files...")
368
 
369
  combined_audio = AudioSegment.empty()
370
  for audio_file in audio_files:
 
375
  combined_audio.export(output_filename, format="wav")
376
 
377
  if progress:
378
+ progress.progress(1.0, "Podcast generated successfully!")
379
 
380
  return output_filename
381
 
382
  async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
383
  try:
384
  if progress:
385
+ progress.progress(0.1, "Starting podcast generation...")
386
 
387
  # Set overall timeout for the entire process
388
  return await asyncio.wait_for(
 
390
  timeout=600 # 10 minutes total timeout
391
  )
392
  except asyncio.TimeoutError:
393
+ raise Exception("The podcast generation process timed out. Please try with shorter text or try again later.")
394
  except Exception as e:
395
+ raise Exception(f"Error generating podcast: {str(e)}")
396
 
397
  async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
398
  if progress:
399
+ progress.progress(0.2, "Generating podcast script...")
400
 
401
  podcast_json = await self.generate_script(input_text, language, api_key, file_obj, progress)
402
 
403
  if progress:
404
+ progress.progress(0.5, "Converting text to speech...")
405
 
406
  # Process TTS in batches to prevent overwhelming the system
407
  audio_files = []
 
410
  for i, item in enumerate(podcast_json['podcast']):
411
  if progress:
412
  current_progress = 0.5 + (0.4 * (i / total_lines))
413
+ progress.progress(current_progress, f"Processing speech {i+1}/{total_lines}...")
414
 
415
  try:
416
  audio_file = await self.tts_generate(item['line'], item['speaker'], speaker1, speaker2)
 
420
  for file in audio_files:
421
  if os.path.exists(file):
422
  os.remove(file)
423
+ raise Exception(f"Error generating speech for line {i+1}: {str(e)}")
424
 
425
  combined_audio = await self.combine_audio_files(audio_files, progress)
426
  return combined_audio
427
+
428
+ async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "", progress=None) -> str:
429
  start_time = time.time()
430
 
431
  voice_names = {
 
443
  speaker2 = voice_names[speaker2]
444
 
445
  try:
446
+ if progress:
447
+ progress.progress(0.05, "Processing input...")
448
 
449
  if not api_key:
450
  api_key = os.getenv("GENAI_API_KEY")
451
  if not api_key:
452
+ raise Exception("No API key provided. Please provide a Gemini API key.")
453
 
454
  podcast_generator = PodcastGenerator()
455
  podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key, input_file, progress)
 
462
  # Ensure we show a user-friendly error
463
  error_msg = str(e)
464
  if "rate limit" in error_msg.lower():
465
+ raise Exception("Rate limit exceeded. Please try again later or use your own API key.")
466
  elif "timeout" in error_msg.lower():
467
+ raise Exception("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.")
468
  else:
469
+ raise Exception(f"Error: {error_msg}")
470
 
471
+ # Streamlit UI
472
+ def main():
473
+ st.set_page_config(page_title="PodcastGen 🎙️", page_icon="🎙️", layout="wide")
474
+
475
+ st.title("PodcastGen 🎙️")
476
+ st.write("Generate a 2-speaker podcast from text input or documents!")
477
+
478
+ with st.sidebar:
479
+ st.header("Configuration")
480
+ api_key = st.text_input("Your Gemini API Key (Optional)", type="password",
481
+ help="In case you are getting rate limited")
482
+
483
+ language_options = [
484
  "Auto Detect",
485
  "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian", "Azerbaijani",
486
  "Bahasa Indonesian", "Bangla", "Basque", "Bengali", "Bosnian", "Bulgarian",
 
495
  "Slovak", "Slovene", "Somali", "Spanish", "Sundanese", "Swahili",
496
  "Swedish", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian", "Urdu",
497
  "Uzbek", "Vietnamese", "Welsh", "Zulu"
498
+ ]
499
+ language = st.selectbox("Language", language_options, index=0)
500
+
501
+ voice_options = [
 
 
 
 
 
 
 
 
 
 
502
  "Andrew - English (United States)",
503
  "Ava - English (United States)",
504
  "Brian - English (United States)",
 
507
  "Seraphina - German (Germany)",
508
  "Remy - French (France)",
509
  "Vivienne - French (France)"
510
+ ]
511
+ speaker1 = st.selectbox("Speaker 1 Voice", voice_options, index=0)
512
+ speaker2 = st.selectbox("Speaker 2 Voice", voice_options, index=1)
513
+
514
+ col1, col2 = st.columns([2, 1])
515
+
516
+ with col1:
517
+ input_text = st.text_area("Input Text", height=250)
518
+
519
+ with col2:
520
+ uploaded_file = st.file_uploader("Or Upload a PDF or TXT file", type=["pdf", "txt"])
521
+
522
+ if st.button("Generate Podcast"):
523
+ if not input_text and not uploaded_file:
524
+ st.error("Please provide either input text or upload a file.")
525
+ return
526
+
527
+ # Create a progress bar for the async operation
528
+ progress_bar = st.progress(0)
529
+ status_text = st.empty()
530
+
531
+ # Create a progress wrapper for compatibility with the existing code
532
+ class StreamlitProgress:
533
+ def progress(self, value, text=None):
534
+ progress_bar.progress(value)
535
+ if text:
536
+ status_text.text(text)
537
+
538
+ try:
539
+ # Prepare file if uploaded
540
+ file_obj = None
541
+ if uploaded_file:
542
+ # Save the uploaded file to a temporary location
543
+ file_path = f"temp_upload_{uuid.uuid4()}{os.path.splitext(uploaded_file.name)[1]}"
544
+ with open(file_path, "wb") as f:
545
+ f.write(uploaded_file.getbuffer())
546
+
547
+ class FileWrapper:
548
+ def __init__(self, path, name):
549
+ self.name = name
550
+ self.path = path
551
+
552
+ @property
553
+ def name(self):
554
+ return self._name
555
+
556
+ @name.setter
557
+ def name(self, value):
558
+ self._name = value
559
+
560
+ file_obj = FileWrapper(file_path, uploaded_file.name)
561
+ file_obj.name = file_path # Set the path as the name for proper file reading
562
+
563
+ # Run the async function in a new event loop
564
+ progress_wrapper = StreamlitProgress()
565
+ audio_file = asyncio.run(process_input(
566
+ input_text,
567
+ file_obj,
568
+ language,
569
+ speaker1,
570
+ speaker2,
571
+ api_key,
572
+ progress_wrapper
573
+ ))
574
+
575
+ # Display the audio
576
+ st.subheader("Generated Podcast")
577
+ st.audio(audio_file, format="audio/wav")
578
+
579
+ # Provide a download button
580
+ with open(audio_file, "rb") as f:
581
+ audio_bytes = f.read()
582
+
583
+ st.download_button(
584
+ label="Download Podcast",
585
+ data=audio_bytes,
586
+ file_name="podcast.wav",
587
+ mime="audio/wav"
588
+ )
589
+
590
+ # Clean up the temporary file
591
+ if file_obj:
592
+ try:
593
+ os.remove(file_path)
594
+ except:
595
+ pass
596
+
597
+ except Exception as e:
598
+ st.error(str(e))
599
 
600
  if __name__ == "__main__":
601
+ main()