# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
from datetime import datetime

import gradio as gr
import soundfile as sf
import spaces
import torch
from huggingface_hub import snapshot_download

from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI


def initialize_model(model_dir=snapshot_download("SparkAudio/Spark-TTS-0.5B"), device=0):
    """Load the model once at the beginning."""
    logging.info(f"Loading model from: {model_dir}")
    # The model is always placed on CUDA here; the `device` argument is kept
    # for API compatibility with the non-Spaces CLI entry point.
    device = torch.device("cuda")
    model = SparkTTS(model_dir, device)
    return model


@spaces.GPU
def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Saving audio to: {save_dir}")

    # Treat an empty or single-character prompt text as "no prompt".
    if prompt_text is not None:
        prompt_text = None if len(prompt_text) <= 1 else prompt_text

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Generate a unique filename using a timestamp
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )
        sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")

    return save_path, model  # Return model along with audio path


@spaces.GPU
def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
    """Gradio interface for TTS with prompt speech input."""
    # Determine prompt speech (from uploaded audio file or recording)
    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
    # Guard against an empty textbox before checking its length.
    prompt_text = None if not prompt_text or len(prompt_text) < 2 else prompt_text

    audio_output_path, model = run_tts(
        text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
    )
    return audio_output_path, model


@spaces.GPU
def voice_creation(text, model, gender, pitch, speed):
    """Gradio interface for TTS with control over voice attributes."""
    # Map the UI slider positions (integers) to the model's pitch/speed labels.
    pitch = LEVELS_MAP_UI[int(pitch)]
    speed = LEVELS_MAP_UI[int(speed)]

    audio_output_path, model = run_tts(
        text, model, gender=gender, pitch=pitch, speed=speed
    )
    return audio_output_path, model


def build_ui(model_dir, device=0):
    with gr.Blocks() as demo:
        # Initialize model
        model = initialize_model(model_dir, device=device)

        # Use HTML for centered title
        gr.HTML('