# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import torch
from typing import Tuple
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

from sparktts.utils.file import load_config
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP


class SparkTTS:
    """
    Spark-TTS for text-to-speech generation.
    """

    def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
        """
        Initializes the SparkTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device
        self.model_dir = model_dir
        self.configs = load_config(f"{model_dir}/config.yaml")
        self.sample_rate = self.configs["sample_rate"]
        self._initialize_inference()

    def _initialize_inference(self):
        """Initializes the tokenizer, model, and audio tokenizer for inference."""
        self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
        self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
        self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
        self.model.to(self.device)

    def process_prompt(
        self,
        text: str,
        prompt_speech_path: Path,
        prompt_text: str = None,
    ) -> Tuple[str, torch.Tensor]:
        """
        Process input for voice cloning.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.

        Returns:
            Tuple[str, torch.Tensor]: The serialized input prompt and the
                global token ids extracted from the prompt audio.
        """
        # Tokenize the reference audio into global (speaker) and semantic tokens.
        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
            prompt_speech_path
        )
        global_tokens = "".join(
            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
        )

        # Prepare the input tokens for the model
        if prompt_text is not None:
            semantic_tokens = "".join(
                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
            )
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                prompt_text,
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
                "<|start_semantic_token|>",
                semantic_tokens,
            ]
        else:
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
            ]

        inputs = "".join(inputs)

        return inputs, global_token_ids
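
    # For reference, the cloning prompt built above concatenates, with no
    # separators:
    #
    #   {TASK_TOKEN_MAP["tts"]}
    #   <|start_content|>{prompt_text}{text}<|end_content|>
    #   <|start_global_token|>{global_tokens}<|end_global_token|>
    #   <|start_semantic_token|>{semantic_tokens}
    #
    # When prompt_text is None, the content span holds only `text` and the
    # semantic-token span is omitted, so the model must generate it.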

    def process_prompt_control(
        self,
        gender: str,
        pitch: str,
        speed: str,
        text: str,
    ) -> str:
        """
        Process input for voice creation.

        Args:
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            text (str): The text input to be converted to speech.

        Returns:
            str: Input prompt
        """
        assert gender in GENDER_MAP.keys()
        assert pitch in LEVELS_MAP.keys()
        assert speed in LEVELS_MAP.keys()

        gender_id = GENDER_MAP[gender]
        pitch_level_id = LEVELS_MAP[pitch]
        speed_level_id = LEVELS_MAP[speed]

        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
        gender_tokens = f"<|gender_{gender_id}|>"

        attribute_tokens = "".join(
            [gender_tokens, pitch_label_tokens, speed_label_tokens]
        )

        control_tts_inputs = [
            TASK_TOKEN_MAP["controllable_tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            attribute_tokens,
            "<|end_style_label|>",
        ]

        return "".join(control_tts_inputs)

    @torch.no_grad()
    def inference(
        self,
        text: str,
        prompt_speech_path: Path = None,
        prompt_text: str = None,
        gender: str = None,
        pitch: str = None,
        speed: str = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> torch.Tensor:
        """
        Performs inference to generate speech from text, incorporating prompt audio and/or text.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path, optional): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.
            gender (str, optional): female | male. When provided, controllable
                voice creation is used instead of voice cloning.
            pitch (str, optional): very_low | low | moderate | high | very_high
            speed (str, optional): very_low | low | moderate | high | very_high
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (int, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            torch.Tensor: Generated waveform as a tensor.
        """
        if gender is not None:
            # Voice creation: style is fully specified by the attribute labels.
            prompt = self.process_prompt_control(gender, pitch, speed, text)
        else:
            # Voice cloning: global (speaker) tokens come from the reference audio.
            prompt, global_token_ids = self.process_prompt(
                text, prompt_speech_path, prompt_text
            )

        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        # Generate speech using the model
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=3000,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        # Trim the output tokens to remove the input tokens
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode the generated tokens into text
        predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Extract semantic token IDs from the generated text
        pred_semantic_ids = (
            torch.tensor(
                [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]
            )
            .long()
            .unsqueeze(0)
        )

        if gender is not None:
            # In voice creation the model also generates the global tokens itself.
            global_token_ids = (
                torch.tensor(
                    [int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]
                )
                .long()
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Convert semantic tokens back to waveform
        wav = self.audio_tokenizer.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            pred_semantic_ids.to(self.device),
        )

        return wav
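

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes a
# hypothetical model directory "pretrained_models/Spark-TTS-0.5B" and a
# reference clip "prompt.wav"; adjust paths to your setup. It also hedges on
# the return type of inference(), normalizing to a numpy array before writing
# the waveform with soundfile.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import soundfile as sf

    def to_numpy(wav):
        # detokenize() may yield a tensor or a numpy array; normalize for sf.write.
        return wav.squeeze().cpu().numpy() if torch.is_tensor(wav) else wav

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tts = SparkTTS(Path("pretrained_models/Spark-TTS-0.5B"), device=device)

    # Voice creation: the style is controlled by gender/pitch/speed labels.
    created = tts.inference(
        "Hello, this voice was created from attribute labels alone.",
        gender="female",
        pitch="moderate",
        speed="moderate",
    )
    sf.write("created.wav", to_numpy(created), tts.sample_rate)

    # Voice cloning: the style is taken from a reference recording.
    cloned = tts.inference(
        "Hello, this voice was cloned from a reference clip.",
        prompt_speech_path=Path("prompt.wav"),
        prompt_text="Transcript of the reference clip.",  # optional
    )
    sf.write("cloned.wav", to_numpy(cloned), tts.sample_rate)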