clone / app.py
nikkmitra's picture
Update app.py
2ea96fb verified
raw
history blame
4.25 kB
import gradio as gr
import torch
from TTS.api import TTS
import os
import spaces
import tempfile
from pymongo import MongoClient
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
# Populate os.environ from a local .env file when one exists (values already
# present in the process environment are used as-is).
load_dotenv()

# Connection string for the voice catalogue and access token for Hub downloads.
mongodb_uri = os.environ.get('MONGODB_URI')
hf_token = os.environ.get('HF_TOKEN')

# Voice catalogue: database "mitra", collection "voices".
client = MongoClient(mongodb_uri)
db = client['mitra']
voices_collection = db['voices']

# Pre-accept the Coqui model licence; presumably this lets the XTTS download
# proceed without an interactive prompt (TODO confirm against Coqui TTS docs).
os.environ["COQUI_TOS_AGREED"] = "1"

# Target device for the TTS model: first CUDA GPU if visible, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize TTS model
def load_tts_model():
    """Construct the multilingual XTTS v2 model and move it onto ``device``."""
    model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
    return model.to(device)


# Created once at import time; every request below reuses this instance.
tts = load_tts_model()
# Fetch celebrity voices from MongoDB
def get_celebrity_voices(collection=None):
    """Map every catalogued voice name to its sample path in the Space repo.

    Each document returned by ``collection.find()`` is treated as a category
    whose ``voices`` list holds entries with a ``name`` field. The reference
    sample for a voice named ``X`` is expected at ``voices/X.mp3``.

    Args:
        collection: Any object with a pymongo-style ``find()`` method.
            Defaults to the module-level ``voices_collection``.

    Returns:
        dict mapping voice name -> repo-relative ``.mp3`` path. If the same
        name appears in several categories, the last one seen wins.
    """
    if collection is None:
        collection = voices_collection
    voices = {}
    for category in collection.find():
        # A category without a "voices" array (or with null) contributes
        # nothing instead of aborting app startup with KeyError/TypeError.
        for voice in category.get('voices') or []:
            name = voice.get('name')
            if not name:
                # Nameless entries cannot be offered in the dropdown or
                # mapped to a sample file; skip them.
                continue
            voices[name] = f"voices/{name}.mp3"
    return voices
# Voice name -> repo-relative sample path, queried from MongoDB once at import
# time. Read by check_voice_files, tts_generate and the TTS voice dropdown;
# voices added to the database later require a restart to appear.
celebrity_voices = get_celebrity_voices()
def check_voice_files(voices=None, repo_id="nikkmitra/clone"):
    """
    Check that every catalogued voice sample can be fetched from the Hub.

    Each sample is fetched with ``hf_hub_download``. Any failure (file
    absent, auth or network error) is reported rather than raised, so the
    UI can still start.

    Args:
        voices: Mapping of voice name -> repo-relative path. Defaults to the
            module-level ``celebrity_voices``.
        repo_id: Space repository that holds the samples.

    Returns:
        A Markdown string: either a bullet list of unavailable files (with
        the exception type for each) or a confirmation that all are present.
    """
    if voices is None:
        voices = celebrity_voices
    missing = []
    for voice, path in voices.items():
        try:
            hf_hub_download(repo_id=repo_id, filename=path, repo_type="space", token=hf_token)
        except Exception as exc:  # best-effort status report; never crash startup
            # Bullet + code span: a bare "\n" is not a line break in Markdown,
            # and underscores in file names would otherwise render as emphasis.
            missing.append(f"- {voice}: `{path}` ({type(exc).__name__})")
    if missing:
        # Blank line keeps the list from being merged into the heading paragraph.
        return "**Missing Voice Files:**\n\n" + "\n".join(missing)
    return "**All voice files are present.** 🎉"
@spaces.GPU(duration=120)
def tts_generate(text, voice, language):
    """Synthesize *text* in one of the catalogued celebrity voices.

    Args:
        text: Text to speak.
        voice: Key of ``celebrity_voices`` picked in the dropdown; may be
            None when nothing has been selected.
        language: Language code passed straight to XTTS (e.g. "en").

    Returns:
        Path of a temporary ``.wav`` file with the generated speech. It is
        created with ``delete=False`` so it still exists when Gradio reads it.

    Raises:
        gr.Error: If the text is blank or the voice is not a known voice.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to speak.")
    if voice not in celebrity_voices:
        # Previously an opaque KeyError (e.g. voice=None with no selection).
        raise gr.Error("Please select a celebrity voice.")
    # Fetch the reference sample before creating the output file so a failed
    # download does not leave an orphaned temp file behind.
    voice_file = hf_hub_download(repo_id="nikkmitra/clone", filename=celebrity_voices[voice], repo_type="space", token=hf_token)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name
    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=voice_file,
            language=language,
            file_path=temp_audio_path
        )
    except Exception:
        # Do not leak the (empty) output file when synthesis fails.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
        raise
    return temp_audio_path
# Same GPU budget as tts_generate: both run the same XTTS inference, so the
# same text should not time out in one tab and succeed in the other.
@spaces.GPU(duration=120)
def clone_voice(text, audio_file, language):
    """Synthesize *text* in the voice of a user-supplied reference recording.

    Args:
        text: Text to speak.
        audio_file: Filesystem path of the uploaded reference clip (the input
            component is declared with ``type="filepath"``); None when
            nothing was uploaded.
        language: Language code passed straight to XTTS (e.g. "en").

    Returns:
        Path of a temporary ``.wav`` file with the generated speech. It is
        created with ``delete=False`` so it still exists when Gradio reads it.

    Raises:
        gr.Error: If the text is blank or no reference audio was provided.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to speak.")
    if not audio_file:
        raise gr.Error("Please upload a voice reference audio file.")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio_path = temp_audio.name
    try:
        tts.tts_to_file(
            text=text,
            speaker_wav=audio_file,
            language=language,
            file_path=temp_audio_path
        )
    except Exception:
        # Do not leak the (empty) output file when synthesis fails.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
        raise
    return temp_audio_path
# Define Gradio interface
# Language codes offered by both tabs. One shared list so the tabs cannot
# drift apart (the Clone tab previously lacked "hi").
LANGUAGES = ["en", "es", "fr", "de", "it", "ar", "hi"]

with gr.Blocks() as demo:
    gr.Markdown("# Advanced Voice Synthesis")
    # Display voice files status (computed once, when the UI is built)
    voice_status = check_voice_files()
    gr.Markdown(voice_status)
    with gr.Tabs():
        with gr.TabItem("TTS"):
            with gr.Row():
                tts_text = gr.Textbox(label="Text to speak")
                tts_voice = gr.Dropdown(choices=list(celebrity_voices.keys()), label="Celebrity Voice")
                tts_language = gr.Dropdown(LANGUAGES, label="Language", value="en")
            tts_generate_btn = gr.Button("Generate")
            tts_output = gr.Audio(label="Generated Audio")
            tts_generate_btn.click(
                tts_generate,
                inputs=[tts_text, tts_voice, tts_language],
                outputs=tts_output
            )
        with gr.TabItem("Clone Voice"):
            with gr.Row():
                clone_text = gr.Textbox(label="Text to speak")
                # type="filepath" hands clone_voice a path on disk, not raw samples.
                clone_audio = gr.Audio(label="Voice reference audio file", type="filepath")
                clone_language = gr.Dropdown(LANGUAGES, label="Language", value="en")
            clone_generate_btn = gr.Button("Generate")
            clone_output = gr.Audio(label="Generated Audio")
            clone_generate_btn.click(
                clone_voice,
                inputs=[clone_text, clone_audio, clone_language],
                outputs=clone_output
            )
# Launch the interface
demo.launch()
# Clean up temporary files (this will run after the Gradio server is closed).
# tempfile.NamedTemporaryFile() with no dir= argument writes into
# tempfile.gettempdir(), not the working directory, so that is where the
# generated "tmp*.wav" outputs live. Scanning os.listdir() (the CWD) never
# found them.
# NOTE(review): in a shared temp dir this also removes other processes'
# tmp*.wav files; acceptable in a single-app container — confirm if reused.
_temp_dir = tempfile.gettempdir()
for _name in os.listdir(_temp_dir):
    if _name.startswith('tmp') and _name.endswith('.wav'):
        try:
            os.remove(os.path.join(_temp_dir, _name))
        except OSError:
            # Best-effort shutdown cleanup: skip entries that already vanished
            # or cannot be removed (directories, permission errors).
            pass