# Spark-TTS-0.5B / cli/inference.py
# Copyright (c) 2025 SparkAudio
# 2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import logging
from datetime import datetime

import torch
import soundfile as sf

from cli.SparkTTS import SparkTTS


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Run TTS inference.")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="pretrained_models/Spark-TTS-0.5B",
        help="Path to the model directory",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="example/results",
        help="Directory to save generated audio files",
    )
    parser.add_argument("--device", type=int, default=0, help="CUDA device number")
    parser.add_argument(
        "--text", type=str, required=True, help="Text for TTS generation"
    )
    parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio")
    parser.add_argument(
        "--prompt_speech_path",
        type=str,
        help="Path to the prompt audio file",
    )
    parser.add_argument("--gender", choices=["male", "female"])
    parser.add_argument(
        "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"]
    )
    parser.add_argument(
        "--speed", choices=["very_low", "low", "moderate", "high", "very_high"]
    )
    return parser.parse_args()
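
# Note (assumption; the actual behavior lives in SparkTTS.inference, not in this file):
# --prompt_speech_path and --prompt_text are intended for voice cloning from a reference
# recording, while --gender, --pitch and --speed are intended for controllable synthesis.
# Options left unspecified default to None and are passed through to model.inference()
# unchanged.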


def run_tts(args):
    """Perform TTS inference and save the generated audio."""
    logging.info(f"Using model from: {args.model_dir}")
    logging.info(f"Saving audio to: {args.save_dir}")

    # Ensure the save directory exists
    os.makedirs(args.save_dir, exist_ok=True)

    # Convert device argument to torch.device
    device = torch.device(f"cuda:{args.device}")

    # Initialize the model
    model = SparkTTS(args.model_dir, device)

    # Generate unique filename using timestamp
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    save_path = os.path.join(args.save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    with torch.no_grad():
        wav = model.inference(
            args.text,
            args.prompt_speech_path,
            prompt_text=args.prompt_text,
            gender=args.gender,
            pitch=args.pitch,
            speed=args.speed,
        )
        sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    args = parse_args()
    run_tts(args)
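
# Example invocations (illustrative; the file paths and prompt audio below are
# placeholders, not assets shipped with this script). Run from the repository root
# so that `cli` is importable as a package:
#
#   # Voice cloning from a reference recording
#   python -m cli.inference \
#       --model_dir pretrained_models/Spark-TTS-0.5B \
#       --text "Text to synthesize." \
#       --prompt_speech_path example/prompt.wav \
#       --prompt_text "Transcript of the prompt audio."
#
#   # Controllable synthesis via the gender/pitch/speed options
#   python -m cli.inference \
#       --text "Text to synthesize." \
#       --gender female --pitch moderate --speed moderate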