# clone / app.py
# nikkmitra's picture
# Update app.py
# 1d877e3 verified
import gradio as gr
import torch
from TTS.api import TTS
import os
import spaces
import tempfile
from pymongo import MongoClient
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
# Load environment variables (the values read below are expected to come from a .env file)
load_dotenv()
# Get MongoDB URI and Hugging Face token from the environment.
# os.getenv returns None for a variable that is not set.
mongodb_uri = os.getenv('MONGODB_URI')
hf_token = os.getenv('HF_TOKEN')
# Connect to MongoDB and select the 'voices' collection of the 'mitra' database
client = MongoClient(mongodb_uri)
db = client['mitra']
voices_collection = db['voices']
# NOTE(review): presumably pre-accepts the Coqui TTS terms of service so the
# model load does not prompt interactively - confirm against the Coqui TTS docs.
os.environ["COQUI_TOS_AGREED"] = "1"
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the TTS model once, at import time
def load_tts_model():
    """Create the multilingual XTTS v2 model and move it onto ``device``."""
    model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    return model.to(device)


tts = load_tts_model()
# Build the voice-name -> repository-path lookup from MongoDB
def get_celebrity_voices():
    """
    Return a dict mapping every voice name to ``voices/<name>.mp3``.

    Each document of ``voices_collection`` is read for its ``'voices'`` list,
    and each entry of that list for its ``'name'``. A name that occurs more
    than once simply keeps the (identical) path computed last.
    """
    return {
        entry['name']: f"voices/{entry['name']}.mp3"
        for document in voices_collection.find()
        for entry in document['voices']
    }


celebrity_voices = get_celebrity_voices()
def check_voice_files():
    """
    Try to download every configured voice file from the Hugging Face Space.

    Returns a Markdown string: either a list of ``name: path`` entries whose
    download raised an exception, or a confirmation that all files were
    fetched successfully.
    """
    def _can_download(repo_path):
        # Any exception raised by the download counts as "missing".
        try:
            hf_hub_download(
                repo_id="nikkmitra/clone",
                filename=repo_path,
                repo_type="space",
                token=hf_token,
            )
        except Exception:
            return False
        return True

    unavailable = [
        f"{name}: {repo_path}"
        for name, repo_path in celebrity_voices.items()
        if not _can_download(repo_path)
    ]
    if not unavailable:
        return "**All voice files are present.** 🎉"
    return "**Missing Voice Files:**\n" + "\n".join(unavailable)
# Split text into newline-separated chunks of at most `max_tokens` tokens
def split_text_into_chunks(text, max_tokens=100, language="en"):
    """
    Split ``text`` into chunks of at most ``max_tokens`` tokens each.

    Tokens are whitespace-separated words. The words of one chunk are joined
    with single spaces and the chunks are separated by a newline, so runs of
    whitespace in the input are normalised.

    Args:
        text: Text to split. ``None`` or an empty string yields ``""``.
        max_tokens: Maximum number of tokens per chunk; must be at least 1.
        language: Language code of ``text``. Accepted for interface
            compatibility; the same whitespace tokenisation is applied to
            every language (no language-specific tokenizer is used here).

    Returns:
        The chunks joined by ``"\\n"``.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")
    # The previous body iterated over `tokens` without ever defining it, so
    # every call raised NameError. Whitespace splitting keeps whole words
    # intact, which is what the space-joined output requires.
    tokens = text.split() if text else []
    chunks = [
        ' '.join(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
    return '\n'.join(chunks)
@spaces.GPU(duration=120)
def tts_generate(text, voice, language):
    """
    Synthesize ``text`` in one of the stored celebrity voices.

    Args:
        text: Text to speak; passed to the TTS model unmodified.
        voice: Key of ``celebrity_voices`` selecting the reference recording.
        language: Language code passed to the TTS model.

    Returns:
        Path of a temporary ``.wav`` file holding the generated speech.

    Raises:
        gr.Error: If no known voice is selected, the reference recording
            cannot be downloaded, or synthesis fails. The output component is
            an audio player, so a returned error string would be treated as a
            file path; raising ``gr.Error`` shows the message to the user
            instead.
    """
    if voice not in celebrity_voices:
        raise gr.Error("Please select a celebrity voice.")

    # Download the reference recording first, so a failed download does not
    # leave an orphaned output file behind.
    try:
        voice_file = hf_hub_download(
            repo_id="nikkmitra/clone",
            filename=celebrity_voices[voice],
            repo_type="space",
            token=hf_token,
        )
    except Exception as e:
        raise gr.Error(f"Error downloading voice file: {e}") from e

    # delete=False: the file must outlive this function so the caller can
    # read it; only the reserved path is needed here.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name

    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=voice_file,
            language=language,
            file_path=temp_audio_path
        )
    except Exception as e:
        # Nothing else will remove a delete=False temp file after a failure.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
        if isinstance(e, AssertionError):
            raise gr.Error(f"Error: {e}") from e
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    return temp_audio_path
@spaces.GPU(duration=120)
def clone_voice(text, audio_file, language):
    """
    Synthesize ``text`` in the voice of a user-supplied reference recording.

    Args:
        text: Text to speak; passed to the TTS model unmodified.
        audio_file: Path of the reference recording used as ``speaker_wav``.
        language: Language code passed to the TTS model.

    Returns:
        Path of a temporary ``.wav`` file holding the generated speech.

    Raises:
        gr.Error: If no reference recording was supplied or synthesis fails.
            The output component is an audio player, so a returned error
            string would be treated as a file path; raising ``gr.Error``
            shows the message to the user instead.
    """
    print("cloning")
    if not audio_file:
        raise gr.Error("Please provide a voice reference audio file.")

    # delete=False: the file must outlive this function so the caller can
    # read it; only the reserved path is needed here.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name

    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=audio_file,
            language=language,
            file_path=temp_audio_path
        )
    except Exception as e:
        # Nothing else will remove a delete=False temp file after a failure.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
        if isinstance(e, AssertionError):
            raise gr.Error(f"Error: {e}") from e
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    return temp_audio_path
# Define Gradio interface
# Language codes offered in both tabs
language_choices = ["en", "es", "fr", "de", "it", "ar", "hi"]

with gr.Blocks() as demo:
    gr.Markdown("# Advanced Voice Synthesis")
    # Display voice files status (computed once, while the UI is being built)
    voice_status = check_voice_files()
    gr.Markdown(voice_status)
    with gr.Tabs():
        # Tab 1: speak text in one of the stored celebrity voices
        with gr.TabItem("TTS"):
            with gr.Row():
                tts_text = gr.Textbox(label="Text to speak")
                tts_voice = gr.Dropdown(choices=list(celebrity_voices.keys()), label="Celebrity Voice")
                tts_language = gr.Dropdown(language_choices, label="Language", value="en")
            tts_generate_btn = gr.Button("Generate")
            tts_output = gr.Audio(label="Generated Audio")
            tts_generate_btn.click(
                tts_generate,
                inputs=[tts_text, tts_voice, tts_language],
                outputs=tts_output
            )
        # Tab 2: speak text in the voice of an uploaded reference recording
        with gr.TabItem("Clone Voice"):
            with gr.Row():
                clone_text = gr.Textbox(label="Text to speak")
                clone_audio = gr.Audio(label="Voice reference audio file", type="filepath")
                clone_language = gr.Dropdown(language_choices, label="Language", value="en")
            clone_generate_btn = gr.Button("Generate")
            clone_output = gr.Audio(label="Generated Audio")
            clone_generate_btn.click(
                clone_voice,
                inputs=[clone_text, clone_audio, clone_language],
                outputs=clone_output
            )
# Launch the interface
demo.launch()
# Clean up temporary files (this will run after the Gradio server is closed).
# The generated WAV files are created by tempfile.NamedTemporaryFile, i.e. in
# tempfile.gettempdir() with the default "tmp" prefix - not in the current
# working directory - so that is the directory that has to be scanned.
# NOTE(review): this removes every tmp*.wav in the temp directory, including
# any created by other processes - confirm that is acceptable here.
temp_dir = tempfile.gettempdir()
for file in os.listdir(temp_dir):
    if file.endswith('.wav') and file.startswith('tmp'):
        try:
            os.remove(os.path.join(temp_dir, file))
        except OSError:
            # Best-effort cleanup: the file may already be gone or be in use.
            pass