Spaces:
Runtime error
Runtime error
# Copyright (c) 2025 SparkAudio | |
# 2025 Xinsheng Wang (w.xinshawn@gmail.com) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import argparse | |
import torch | |
import soundfile as sf | |
import logging | |
from datetime import datetime | |
from cli.SparkTTS import SparkTTS | |
def parse_args():
    """Build the TTS inference command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_dir``, ``save_dir``,
        ``device``, ``text``, ``prompt_text``, ``prompt_speech_path``,
        ``gender``, ``pitch`` and ``speed``. Optional flags that are not given
        on the command line are ``None``. ``--text`` is mandatory.
    """
    # Five-step scale shared by the --pitch and --speed controls.
    level_scale = ["very_low", "low", "moderate", "high", "very_high"]

    cli = argparse.ArgumentParser(description="Run TTS inference.")

    # Where the model lives and where generated audio goes.
    cli.add_argument(
        "--model_dir",
        type=str,
        default="pretrained_models/Spark-TTS-0.5B",
        help="Path to the model directory",
    )
    cli.add_argument(
        "--save_dir",
        type=str,
        default="example/results",
        help="Directory to save generated audio files",
    )
    cli.add_argument("--device", type=int, default=0, help="CUDA device number")

    # What to say, plus optional voice-cloning prompt.
    cli.add_argument("--text", type=str, required=True, help="Text for TTS generation")
    cli.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
    cli.add_argument(
        "--prompt_speech_path", type=str, help="Path to the prompt audio file"
    )

    # Optional voice attribute controls.
    cli.add_argument("--gender", choices=["male", "female"])
    cli.add_argument("--pitch", choices=level_scale)
    cli.add_argument("--speed", choices=level_scale)

    parsed = cli.parse_args()
    return parsed
def _select_device(device_index):
    """Choose the torch device used for inference.

    Args:
        device_index: CUDA device ordinal requested via ``--device``.

    Returns:
        ``torch.device("cuda:<device_index>")`` when CUDA is available,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{device_index}")
    # Unconditionally requesting CUDA crashes on GPU-less hosts; fall back to
    # CPU so inference still runs (more slowly) instead of erroring out.
    logging.warning(
        f"CUDA is not available; ignoring --device {device_index} and using CPU."
    )
    return torch.device("cpu")


def run_tts(args):
    """Perform TTS inference and save the generated audio.

    Args:
        args: Namespace as produced by ``parse_args()``. Reads ``model_dir``,
            ``save_dir``, ``device``, ``text``, ``prompt_text``,
            ``prompt_speech_path``, ``gender``, ``pitch`` and ``speed``.

    Returns:
        Path of the written ``.wav`` file. Callers that ignored the former
        ``None`` return value are unaffected.
    """
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Ensure the save directory exists
    os.makedirs(args.save_dir, exist_ok=True)

    # Use the requested CUDA device when possible, otherwise the CPU.
    device = _select_device(args.device)
    logging.info(f"Using device: {device}")

    # Initialize the model
    model = SparkTTS(args.model_dir, device)

    # Name the output after the current time (second resolution).
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(args.save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        wav = model.inference(
            args.text,
            args.prompt_speech_path,
            prompt_text=args.prompt_text,
            gender=args.gender,
            pitch=args.pitch,
            speed=args.speed,
        )

    # NOTE(review): the sample rate is hard-coded to 16 kHz; presumably this
    # matches the model's output rate. Confirm against the model config.
    sf.write(save_path, wav, samplerate=16000)
    logging.info(f"Audio saved at: {save_path}")
    return save_path
if __name__ == "__main__":
    # Root logger: INFO level, each line stamped with time and severity.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    # Parse the command line, then hand it straight to the inference driver.
    args = parse_args()
    run_tts(args)