open-notebooklm / utils.py
gabriel chua
Chore: Clean up code (#2)
ab25593 unverified
raw
history blame
7.16 kB
"""
utils.py
Functions:
- generate_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- parse_url: Parse the given URL and return the text content.
- generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
"""
# Standard library imports
import time
from typing import Any, Union
# Third-party imports
import requests
from bark import SAMPLE_RATE, generate_audio, preload_models
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from scipy.io.wavfile import write as write_wav
# Local imports
from constants import (
FIREWORKS_API_KEY,
FIREWORKS_BASE_URL,
FIREWORKS_MODEL_ID,
FIREWORKS_MAX_TOKENS,
FIREWORKS_TEMPERATURE,
FIREWORKS_JSON_RETRY_ATTEMPTS,
MELO_API_NAME,
MELO_TTS_SPACES_ID,
MELO_RETRY_ATTEMPTS,
MELO_RETRY_DELAY,
JINA_READER_URL,
JINA_RETRY_ATTEMPTS,
JINA_RETRY_DELAY,
)
from schema import ShortDialogue, MediumDialogue
# Initialize clients (module-level: created once, when this module is imported).
# fw_client: OpenAI-compatible client pointed at FIREWORKS_BASE_URL; used by call_llm.
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
# hf_client: gradio_client connection to the MELO_TTS_SPACES_ID Space; used by _use_melotts_api.
hf_client = Client(MELO_TTS_SPACES_ID)
# Download and load all models for Bark.
# NOTE: this runs at import time, so importing this module triggers the model load.
preload_models()
def generate_script(
    system_prompt: str,
    input_text: str,
    output_model: Union[type[ShortDialogue], type[MediumDialogue]],
) -> Union[ShortDialogue, MediumDialogue]:
    """Get the dialogue from the LLM.

    The LLM is queried in two passes: the first produces a draft dialogue from
    ``input_text``; the second is asked to make that draft more natural and
    engaging. Every response is validated against ``output_model``; when
    validation fails, the LLM is called again with the validation error
    appended to the system prompt.

    Args:
        system_prompt: Base instructions for the LLM.
        input_text: Source text the dialogue is generated from.
        output_model: Dialogue schema class (the class itself, not an
            instance). It provides the JSON schema sent to the LLM and
            validates the returned JSON.

    Returns:
        The validated, improved dialogue as an instance of ``output_model``.

    Raises:
        ValueError: If either pass still yields invalid JSON after
            ``FIREWORKS_JSON_RETRY_ATTEMPTS`` validation attempts.
    """
    # First pass: ask the LLM for a draft dialogue.
    response = call_llm(system_prompt, input_text, output_model)
    response_json = response.choices[0].message.content

    # Validate the draft, re-calling the LLM with the error on each failure.
    for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
        try:
            first_draft_dialogue = output_model.model_validate_json(response_json)
            break
        except ValidationError as e:
            if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1:  # Last attempt
                raise ValueError(
                    f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
                ) from e
            error_message = (
                f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
            )
            # Re-call the LLM with the error message. Only the latest error is
            # appended to the original prompt; errors do not accumulate here.
            system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
            response = call_llm(system_prompt_with_error, input_text, output_model)
            response_json = response.choices[0].message.content
            # The new response is deliberately NOT validated here: the next
            # loop iteration validates it inside the ``try``, so a repeated
            # failure is retried (or reported as ValueError on the last
            # attempt) instead of escaping the retry loop.

    # Second pass: call the LLM again to improve the dialogue.
    system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."

    # Call and validate together; each failure appends its error to the prompt.
    for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
        try:
            response = call_llm(
                system_prompt_with_dialogue,
                "Please improve the dialogue. Make it more natural and engaging.",
                output_model,
            )
            final_dialogue = output_model.model_validate_json(
                response.choices[0].message.content
            )
            break
        except ValidationError as e:
            if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1:  # Last attempt
                raise ValueError(
                    f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
                ) from e
            error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
            system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"

    return final_dialogue
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
    """Call the LLM with the given prompt and dialogue format.

    Args:
        system_prompt: Content of the system message.
        text: Content of the user message.
        dialogue_format: Object exposing ``model_json_schema()``; its schema
            is sent as the required JSON response format.

    Returns:
        The raw chat-completion response returned by ``fw_client``.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    # Ask for a JSON object that follows the dialogue schema.
    json_response_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return fw_client.chat.completions.create(
        messages=conversation,
        model=FIREWORKS_MODEL_ID,
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format=json_response_format,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    The URL is appended to ``JINA_READER_URL`` and fetched with a 60-second
    timeout. Failed requests (including bad HTTP status codes) are retried up
    to ``JINA_RETRY_ATTEMPTS`` times, sleeping ``JINA_RETRY_DELAY`` seconds
    between attempts.

    Args:
        url: Address of the page to read.

    Returns:
        The response body as text.

    Raises:
        ValueError: If the final attempt also fails.
    """
    reader_url = f"{JINA_READER_URL}{url}"
    final_attempt = JINA_RETRY_ATTEMPTS - 1

    for attempt_index in range(JINA_RETRY_ATTEMPTS):
        try:
            response = requests.get(reader_url, timeout=60)
            response.raise_for_status()  # Turn 4xx/5xx responses into exceptions
        except requests.RequestException as exc:
            if attempt_index == final_attempt:
                raise ValueError(
                    f"Failed to fetch URL after {JINA_RETRY_ATTEMPTS} attempts: {exc}"
                ) from exc
            time.sleep(JINA_RETRY_DELAY)  # Back off before the next attempt
        else:
            # Request succeeded; stop retrying.
            break

    return response.text
def generate_podcast_audio(
    text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int
) -> str:
    """Generate audio for podcast using TTS or advanced audio models.

    Args:
        text: The line to be spoken.
        speaker: Speaker label for the line.
        language: Language code passed on to the selected backend.
        use_advanced_audio: When truthy, use the Bark backend; otherwise use
            the MeloTTS API.
        random_voice_number: Voice number used by the Bark backend only.

    Returns:
        Whatever the selected backend returns for the generated audio.
    """
    # Guard clause: the MeloTTS path does not need the voice number.
    if not use_advanced_audio:
        return _use_melotts_api(text, speaker, language)
    return _use_suno_model(text, speaker, language, random_voice_number)
def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str:
    """Generate advanced audio using Bark.

    Args:
        text: The line to be spoken.
        speaker: Speaker label; ``"Host (Jane)"`` uses ``random_voice_number``,
            any other speaker uses ``random_voice_number + 1``.
        language: Language code used in the Bark voice preset name and in the
            output file name.
        random_voice_number: Voice number for the host.

    Returns:
        Path of the WAV file that was written.
    """
    # Give the two speakers adjacent voice presets so they sound different.
    if speaker == "Host (Jane)":
        voice_number = random_voice_number
    else:
        voice_number = random_voice_number + 1

    audio_array = generate_audio(
        text,
        history_prompt=f"v2/{language}_speaker_{voice_number}",
    )

    # write_wav produces a WAV container, so the file carries a .wav
    # extension (it was previously mislabelled as .mp3).
    file_path = f"audio_{language}_{speaker}.wav"
    write_wav(file_path, SAMPLE_RATE, audio_array)
    return file_path
def _use_melotts_api(text: str, speaker: str, language: str) -> str:
    """Generate audio using TTS model.

    Calls the MeloTTS endpoint through ``hf_client``, retrying on any
    exception up to ``MELO_RETRY_ATTEMPTS`` times with ``MELO_RETRY_DELAY``
    seconds between attempts.

    Args:
        text: The line to be spoken.
        speaker: Speaker label used to pick accent and speed.
        language: Language code sent to the API.

    Returns:
        The value returned by ``hf_client.predict``.

    Raises:
        Exception: The exception from the final attempt is re-raised as-is.
    """
    accent, speed = _get_melo_tts_params(speaker, language)
    request = {
        "text": text,
        "language": language,
        "speaker": accent,
        "speed": speed,
        "api_name": MELO_API_NAME,
    }

    attempts_left = MELO_RETRY_ATTEMPTS
    while attempts_left > 0:
        attempts_left -= 1
        try:
            return hf_client.predict(**request)
        except Exception:
            if attempts_left == 0:
                raise  # Out of attempts: surface the last error unchanged
            time.sleep(MELO_RETRY_DELAY)  # Pause before trying again
def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]:
    """Get TTS parameters based on speaker and language.

    Args:
        speaker: ``"Guest"`` selects the guest voice; any other value is
            treated as the host.
        language: Language code; ``"EN"`` has dedicated accents, every other
            code is used as the accent directly.

    Returns:
        An ``(accent, speed)`` pair.
    """
    is_english = language == "EN"

    if speaker == "Guest":
        guest_accent = "EN-US" if is_english else language
        return guest_accent, 0.9

    # Everyone else is the host.
    if is_english:
        return "EN-Default", 1

    # For non-English there is only one voice, so the host is sped up to
    # sound different from the guest.
    return language, 1.1