Spaces:
Running
on
T4
Running
on
T4
""" | |
utils.py | |
Functions: | |
- generate_script: Get the dialogue from the LLM. | |
- call_llm: Call the LLM with the given prompt and dialogue format. | |
- parse_url: Parse the given URL and return the text content. | |
- generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models. | |
""" | |
# Standard library imports | |
import time | |
from typing import Any, Union | |
# Third-party imports | |
import requests | |
from bark import SAMPLE_RATE, generate_audio, preload_models | |
from gradio_client import Client | |
from openai import OpenAI | |
from pydantic import ValidationError | |
from scipy.io.wavfile import write as write_wav | |
# Local imports | |
from constants import ( | |
FIREWORKS_API_KEY, | |
FIREWORKS_BASE_URL, | |
FIREWORKS_MODEL_ID, | |
FIREWORKS_MAX_TOKENS, | |
FIREWORKS_TEMPERATURE, | |
FIREWORKS_JSON_RETRY_ATTEMPTS, | |
MELO_API_NAME, | |
MELO_TTS_SPACES_ID, | |
MELO_RETRY_ATTEMPTS, | |
MELO_RETRY_DELAY, | |
JINA_READER_URL, | |
JINA_RETRY_ATTEMPTS, | |
JINA_RETRY_DELAY, | |
) | |
from schema import ShortDialogue, MediumDialogue | |
# Shared clients, created once when this module is imported.
# Fireworks is reached through the OpenAI SDK pointed at FIREWORKS_BASE_URL
# (an OpenAI-compatible endpoint).
fw_client = OpenAI(api_key=FIREWORKS_API_KEY, base_url=FIREWORKS_BASE_URL)
# Gradio client for the MeloTTS Space used by the standard (non-advanced) TTS path.
hf_client = Client(MELO_TTS_SPACES_ID)
# Import-time side effect: download and load all Bark models used by the
# advanced-audio path.
preload_models()
def generate_script(
    system_prompt: str,
    input_text: str,
    output_model: type[Union[ShortDialogue, MediumDialogue]],
) -> Union[ShortDialogue, MediumDialogue]:
    """Get the dialogue from the LLM in two passes: draft, then refine.

    The first pass turns ``input_text`` into a draft dialogue. The second pass
    shows the LLM its own draft and asks for a more natural, engaging version.
    Each pass retries (see ``_call_llm_and_validate``) when the reply does not
    validate against ``output_model``.

    Args:
        system_prompt: System prompt describing the dialogue to produce.
        input_text: Source text the dialogue is based on. It is sent as the
            user message in the first pass.
        output_model: Model class (``ShortDialogue`` or ``MediumDialogue``).
            Its JSON schema is sent to the LLM and it validates the reply.
            This is the class itself, not an instance.

    Returns:
        The refined dialogue as an ``output_model`` instance.

    Raises:
        ValueError: If either pass still fails validation after
            ``FIREWORKS_JSON_RETRY_ATTEMPTS`` attempts (at least one).
    """
    # Pass 1: draft the dialogue from the source text.
    first_draft_dialogue = _call_llm_and_validate(
        system_prompt,
        input_text,
        output_model,
        failure_label="parse dialogue JSON",
    )

    # Pass 2: feed the draft back and ask for an improved version.
    # NOTE(review): the f-string renders the draft via str(), i.e. the model's
    # field=value repr (presumably pydantic), not JSON. Consider
    # model_dump_json() if the LLM should see the exact schema shape.
    system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
    return _call_llm_and_validate(
        system_prompt_with_dialogue,
        "Please improve the dialogue. Make it more natural and engaging.",
        output_model,
        failure_label="improve dialogue",
    )


def _call_llm_and_validate(
    system_prompt: str,
    user_text: str,
    output_model: type[Union[ShortDialogue, MediumDialogue]],
    *,
    failure_label: str,
) -> Union[ShortDialogue, MediumDialogue]:
    """Call the LLM and validate its JSON reply, feeding errors back on retry.

    Args:
        system_prompt: Base system prompt for every attempt.
        user_text: User message for every attempt.
        output_model: Model class used both as the response schema and as the
            validator.
        failure_label: Phrase completing "Failed to ..." in error messages.

    Returns:
        The first reply that validates against ``output_model``.

    Raises:
        ValueError: If no reply validates within the allowed attempts.
    """
    # Always make at least one attempt, even if the constant is set to <= 0.
    max_attempts = max(1, FIREWORKS_JSON_RETRY_ATTEMPTS)
    prompt = system_prompt
    for attempt in range(max_attempts):
        response = call_llm(prompt, user_text, output_model)
        try:
            return output_model.model_validate_json(
                response.choices[0].message.content
            )
        except ValidationError as e:
            if attempt == max_attempts - 1:  # Last attempt
                raise ValueError(
                    f"Failed to {failure_label} after {max_attempts} attempts: {e}"
                ) from e
            error_message = f"Failed to {failure_label} (attempt {attempt + 1}): {e}"
            # Rebuild from the base prompt so that retries carry only the latest
            # error instead of an ever-growing list of stale ones.
            prompt = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
    # The loop always returns or raises because max_attempts >= 1.
    raise AssertionError("unreachable: retry loop exited without a result")
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
    """Send one JSON-mode chat completion request to the Fireworks model.

    Args:
        system_prompt: Content of the ``system`` message.
        text: Content of the ``user`` message.
        dialogue_format: Model class whose ``model_json_schema()`` is sent as
            the schema the reply must follow.

    Returns:
        The raw completion object returned by ``fw_client``.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    json_response_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return fw_client.chat.completions.create(
        model=FIREWORKS_MODEL_ID,
        messages=conversation,
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format=json_response_format,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    The URL is fetched through the reader service at ``JINA_READER_URL``,
    which is prefixed to it. The request is retried on any
    ``requests.RequestException``, including HTTP error statuses raised by
    ``raise_for_status``.

    Args:
        url: The page URL to read.

    Returns:
        The response body text from the reader service.

    Raises:
        ValueError: If all ``JINA_RETRY_ATTEMPTS`` attempts (at least one) fail.
    """
    # Loop-invariant: build the reader URL once.
    full_url = f"{JINA_READER_URL}{url}"
    # At least one attempt, so a non-positive setting cannot leave `response` unbound.
    max_attempts = max(1, JINA_RETRY_ATTEMPTS)
    last_error = None
    for attempt in range(max_attempts):
        try:
            response = requests.get(full_url, timeout=60)
            response.raise_for_status()  # Turn 4xx/5xx into a RequestException
        except requests.RequestException as e:
            last_error = e
            if attempt < max_attempts - 1:
                time.sleep(JINA_RETRY_DELAY)  # Back off before the next attempt
        else:
            return response.text
    raise ValueError(
        f"Failed to fetch URL after {max_attempts} attempts: {last_error}"
    ) from last_error
def generate_podcast_audio(
    text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int
) -> str:
    """Synthesize one line of podcast speech and return the audio file path.

    Args:
        text: Text to speak.
        speaker: Dialogue speaker label.
        language: Language code forwarded to the selected TTS backend.
        use_advanced_audio: True selects Bark; False selects the MeloTTS Space.
        random_voice_number: Base voice index. Only the Bark path uses it.

    Returns:
        Whatever path the selected backend helper returns.
    """
    if not use_advanced_audio:
        return _use_melotts_api(text, speaker, language)
    return _use_suno_model(text, speaker, language, random_voice_number)
def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str:
    """Generate speech for one line with Bark and save it as a WAV file.

    Args:
        text: Text to speak.
        speaker: Speaker label. ``"Host (Jane)"`` uses voice
            ``random_voice_number``; any other speaker uses the next index.
        language: Language code interpolated into the Bark voice preset name
            ``v2/{language}_speaker_{n}``.
        random_voice_number: Base voice index. NOTE(review): the non-host
            speaker uses ``random_voice_number + 1``; confirm the caller keeps
            that within the available presets.

    Returns:
        Path of the written WAV file. The name depends only on ``language`` and
        ``speaker``, so the next line by the same speaker overwrites it.
    """
    voice_index = random_voice_number if speaker == "Host (Jane)" else random_voice_number + 1
    audio_array = generate_audio(
        text,
        history_prompt=f"v2/{language}_speaker_{voice_index}",
    )
    # scipy's wavfile.write produces WAV data, so give the file a matching
    # extension. It was previously ".mp3", which mislabeled the format for
    # consumers that detect format by extension.
    file_path = f"audio_{language}_{speaker}.wav"
    write_wav(file_path, SAMPLE_RATE, audio_array)
    return file_path
def _use_melotts_api(text: str, speaker: str, language: str) -> str:
    """Generate speech for one line via the MeloTTS Gradio Space, with retries.

    Args:
        text: Text to speak.
        speaker: Dialogue speaker label, mapped to a MeloTTS voice and speed by
            ``_get_melo_tts_params``.
        language: Language code passed through to the Space.

    Returns:
        The result of ``hf_client.predict`` for ``MELO_API_NAME``. This is
        presumably a local path to the generated audio file; confirm against
        the Space's API.

    Raises:
        Exception: Whatever the final ``hf_client.predict`` attempt raised, once
            all ``MELO_RETRY_ATTEMPTS`` attempts (at least one) have failed.
    """
    accent, speed = _get_melo_tts_params(speaker, language)

    def request() -> str:
        return hf_client.predict(
            text=text,
            language=language,
            speaker=accent,
            speed=speed,
            api_name=MELO_API_NAME,
        )

    # Every attempt except the last swallows the error and waits before
    # retrying. The catch is deliberately broad: any failure of the remote call
    # triggers a retry.
    for _ in range(MELO_RETRY_ATTEMPTS - 1):
        try:
            return request()
        except Exception:
            time.sleep(MELO_RETRY_DELAY)
    # The final attempt runs outside the try, so its exception reaches the
    # caller unchanged. This also guarantees at least one attempt, and never an
    # implicit None return, when MELO_RETRY_ATTEMPTS is 0 or negative.
    return request()
def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]:
    """Get the MeloTTS voice (accent) and speaking speed for a dialogue line.

    Args:
        speaker: Dialogue speaker label. ``"Guest"`` is treated as the guest;
            any other value is treated as the host.
        language: MeloTTS language code. ``"EN"`` selects the English voices.

    Returns:
        ``(accent, speed)``. For English, host and guest get different voices
        (``"EN-Default"`` vs ``"EN-US"``). For any other language the accent is
        just the language code, so both speakers share one voice and are told
        apart by speed (host 1.1, guest 0.9).
    """
    if speaker == "Guest":
        accent = "EN-US" if language == "EN" else language
        speed = 0.9
    else:  # host
        accent = "EN-Default" if language == "EN" else language
        # Non-English languages have only one voice, so speed the host up to
        # sound different from the guest. Speed is kept a float, matching the
        # return annotation.
        speed = 1.0 if language == "EN" else 1.1
    return accent, speed