# Spaces:
# Runtime error
# Runtime error
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import uuid
from datetime import datetime

import gradio as gr
import soundfile as sf
import spaces
import torch
from huggingface_hub import snapshot_download

from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI
def initialize_model(model_dir=None, device=0):
    """Load the Spark-TTS model and place it on the requested device.

    Args:
        model_dir: Directory holding the model checkpoint. When ``None``, the
            ``SparkAudio/Spark-TTS-0.5B`` snapshot is resolved through
            ``snapshot_download`` at call time rather than at import time.
        device: Target device. An ``int`` is treated as a CUDA device index,
            falling back to CPU when CUDA is not available. A ``str`` or
            ``torch.device`` is used as given.

    Returns:
        The constructed ``SparkTTS`` instance.
    """
    # Resolve the default lazily: a call in the signature would run as a
    # side effect of merely importing this module.
    if model_dir is None:
        model_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B")
    logging.info(f"Loading model from: {model_dir}")

    # Honor the caller's device choice instead of always forcing "cuda".
    if isinstance(device, int):
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{device}")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(device)

    return SparkTTS(model_dir, device)
def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="example/results",
):
    """Perform TTS inference and save the generated audio.

    Args:
        text: Text to synthesize.
        model: Object exposing ``inference(text, prompt_speech, prompt_text,
            gender, pitch, speed)`` and returning the waveform to write.
        prompt_text: Optional transcript of the prompt speech. Values of at
            most one character are treated as missing.
        prompt_speech: Optional prompt audio forwarded to the model.
        gender: Optional voice attribute forwarded to the model.
        pitch: Optional voice attribute forwarded to the model.
        speed: Optional voice attribute forwarded to the model.
        save_dir: Directory in which the WAV file is written; created if
            it does not exist.

    Returns:
        A ``(save_path, model)`` tuple: the path of the written WAV file and
        the same ``model`` object that was passed in.
    """
    logging.info(f"Saving audio to: {save_dir}")

    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # A second-resolution timestamp alone collides when two requests finish
    # within the same second, so append a random suffix to keep files unique.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    unique_suffix = uuid.uuid4().hex[:8]
    save_path = os.path.join(save_dir, f"{timestamp}_{unique_suffix}.wav")

    logging.info("Starting inference...")

    # Inference needs no gradients
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )

    # NOTE(review): the sample rate is hard-coded to 16 kHz; confirm it
    # matches the rate the model actually produces.
    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")
    return save_path, model  # Return model along with audio path
def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
    """Gradio handler for TTS driven by a prompt (reference) recording.

    Args:
        text: Text to synthesize.
        model: Model object forwarded to ``run_tts``.
        prompt_text: Transcript of the prompt speech. ``None`` or anything
            shorter than two characters is treated as not provided.
        prompt_wav_upload: Uploaded prompt audio; preferred when present.
        prompt_wav_record: Recorded prompt audio; used when nothing was
            uploaded.

    Returns:
        The ``(audio_output_path, model)`` tuple produced by ``run_tts``.
    """
    # Determine prompt speech (uploaded file wins over the recording)
    prompt_speech = prompt_wav_upload or prompt_wav_record

    # Guard against None before measuring the length; calling len(None)
    # would raise TypeError.
    if prompt_text is None or len(prompt_text) < 2:
        prompt_text = None

    audio_output_path, model = run_tts(
        text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
    )
    return audio_output_path, model
def voice_creation(text, model, gender, pitch, speed):
    """Gradio handler for TTS with explicit voice attributes.

    ``pitch`` and ``speed`` are converted to ``int`` and looked up in
    ``LEVELS_MAP_UI`` before being forwarded, together with ``gender``,
    to ``run_tts``.

    Returns:
        The ``(audio_output_path, model)`` pair produced by ``run_tts``.
    """
    # Translate both UI values through the same lookup table, pitch first.
    pitch_level, speed_level = (
        LEVELS_MAP_UI[int(ui_value)] for ui_value in (pitch, speed)
    )
    result_path, result_model = run_tts(
        text,
        model,
        gender=gender,
        pitch=pitch_level,
        speed=speed_level,
    )
    return result_path, result_model
def build_ui(model_dir, device=0):
    """Build the Gradio Blocks app with Voice Clone and Voice Creation tabs.

    Args:
        model_dir: Directory of the model checkpoint, forwarded to
            ``initialize_model``.
        device: Device selector, forwarded to ``initialize_model``.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    # Load the model once. The handlers below share this single instance via
    # closures instead of gr.State: Gradio deep-copies a State's initial
    # value, which would duplicate the whole model for every State component.
    model = initialize_model(model_dir, device=device)

    def _clone_handler(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        # Adapt voice_clone to the UI: supply the shared model, return the path.
        audio_path, _ = voice_clone(
            text, model, prompt_text, prompt_wav_upload, prompt_wav_record
        )
        return audio_path

    def _creation_handler(text, gender, pitch, speed):
        # Adapt voice_creation to the UI: supply the shared model, return the path.
        audio_path, _ = voice_creation(text, model, gender, pitch, speed)
        return audio_path

    with gr.Blocks() as demo:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                clone_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    _clone_handler,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                creation_audio_output = gr.Audio(
                    label="Generated Audio", autoplay=True, streaming=True
                )
                create_button.click(
                    _creation_handler,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )

    return demo
if __name__ == "__main__":
    # Resolve the checkpoint directory first, then build the UI and serve it
    # with request queuing enabled.
    checkpoint_dir = snapshot_download("SparkAudio/Spark-TTS-0.5B")
    demo = build_ui(model_dir=checkpoint_dir, device=0)
    queued_demo = demo.queue()
    queued_demo.launch()