ChatTTS-Story-Telling / tts_model.py
fcyai
11
0b2e3c9
import datetime
import json
import os
import re
import time
import numpy as np
import torch
from tqdm import tqdm
import ChatTTS
from config import DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K
import spaces
def load_chat_tts_model(source='huggingface', force_redownload=False, local_path=None):
    """
    Create a ChatTTS.Chat instance and load its model weights.

    :param source: forwarded unchanged as ``source`` to ``Chat.load_models``
    :param force_redownload: forwarded unchanged as ``force_redownload``
    :param local_path: forwarded as ``custom_path`` to ``Chat.load_models``
    :return: the ``ChatTTS.Chat`` object after ``load_models`` has run
    """
    print("Loading ChatTTS model...")
    model = ChatTTS.Chat()
    # compile=False is always passed; callers cannot enable compilation here.
    load_kwargs = dict(
        source=source,
        force_redownload=force_redownload,
        custom_path=local_path,
        compile=False,
    )
    model.load_models(**load_kwargs)
    return model
def clear_cuda_cache():
    """
    Ask PyTorch's CUDA caching allocator to release its unused cached memory.

    :return: None
    """
    cuda = torch.cuda
    cuda.empty_cache()
def deterministic(seed=0):
    """
    Seed the random number generators used here and pin cuDNN to deterministic mode.

    Seeds torch (via ``torch.manual_seed`` and ``torch.cuda.manual_seed``) and
    NumPy's global RNG with the same value. It also sets
    ``torch.backends.cudnn.deterministic = True`` and ``benchmark = False``.

    :param seed: integer seed applied to every generator listed above
    :return: None
    """
    # ref: https://github.com/Jackiexiao/ChatTTS-api-ui-docker/blob/main/api.py#L27
    for seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
def _resolve_speaker_embedding(chat, seed, speaker_type, roleid, pt_file):
    """
    Select the speaker embedding used by ``generate_audio_for_seed``.

    :param chat: loaded ChatTTS model (used for ``sample_random_speaker``)
    :param seed: requested seed; for ``speaker_type == "seed"``, any of
        None / -1 / 0 / "" / "random" is replaced with a random int in [0, 9999)
    :param speaker_type: "seed", "role" or "pt"
    :param roleid: key into ./slct_voice_240605.json (used only for "role")
    :param pt_file: path passed to ``torch.load`` (used only for "pt")
    :return: ``(spk_emb, seed)``, where ``seed`` is the possibly re-drawn value
    :raises KeyError: ``roleid`` is not present in the voice JSON
    :raises ValueError: the .pt payload is not a tensor of shape (768,), or
        ``speaker_type`` is unknown
    """
    if speaker_type == "seed":
        if seed in [None, -1, 0, "", "random"]:
            seed = np.random.randint(0, 9999)
        deterministic(seed)
        return chat.sample_random_speaker(), seed

    if speaker_type == "role":
        # Load the saved voice table. Each entry stores its embedding under "tensor".
        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
            voices = json.load(json_file)
        # JSON object keys are always strings, so also accept integer role ids.
        key = str(roleid)
        if key not in voices:
            raise KeyError(f"Unknown roleid {roleid!r} in ./slct_voice_240605.json")
        # Only the selected voice needs converting back into a tensor.
        # The embedding is used as-is to fix the voice; lowering temperature was
        # considered but is left disabled (temperature = 0.001).
        return torch.tensor(voices[key]["tensor"]), seed

    if speaker_type == "pt":
        print(pt_file)
        # NOTE(security): torch.load may unpickle arbitrary objects; only load trusted files.
        spk_emb = torch.load(pt_file)
        print(getattr(spk_emb, "shape", type(spk_emb)))
        if not isinstance(spk_emb, torch.Tensor) or spk_emb.shape != (768,):
            raise ValueError("维度应为 768。")
        return spk_emb, seed

    raise ValueError(f"Invalid speaker_type: {speaker_type}. ")


@spaces.GPU
def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, roleid=None,
                            temperature=DEFAULT_TEMPERATURE,
                            top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False,
                            skip_refine_text=False, speaker_type="seed", pt_file=None):
    """
    Synthesize ``texts`` with ChatTTS in batches and optionally save the combined audio.

    :param chat: loaded ChatTTS model
    :param seed: speaker seed (see ``_resolve_speaker_embedding`` for random handling)
    :param texts: list of text segments to synthesize
    :param batch_size: number of segments per ``chat.infer`` call
    :param speed: inserted into the infer prompt as ``[speed_{speed}]``
    :param refine_text_prompt: prompt for the refine stage; also embedded in the output filename
    :param roleid: voice key used when ``speaker_type == "role"``
    :param temperature: sampling temperature for both infer and refine
    :param top_P: top-p for both infer and refine
    :param top_K: top-k for both infer and refine
    :param cur_tqdm: progress-bar factory called like ``tqdm``; defaults to ``tqdm``
    :param skip_save: if true, return the raw list of wavs instead of saving
    :param skip_refine_text: skip the refine stage. It is forced to True when
        the text contains [uv_break] or [laugh].
    :param speaker_type: "seed", "role" or "pt"
    :param pt_file: path to a saved (768,) speaker tensor used when ``speaker_type == "pt"``
    :return: list of wavs when ``skip_save``; otherwise whatever ``utils.save_audio`` returns
    :raises KeyError: unknown ``roleid``
    :raises ValueError: invalid ``speaker_type`` or invalid .pt embedding
    """
    from utils import combine_audio, save_audio, batch_split

    print(f'generate_audio_for_seed,0, speaker_type:{speaker_type}')
    rnd_spk_emb, seed = _resolve_speaker_embedding(chat, seed, speaker_type, roleid, pt_file)
    print(f'generate_audio_for_seed,1, speaker_type:{speaker_type}')

    params_infer_code = {
        'spk_emb': rnd_spk_emb,
        'prompt': f'[speed_{speed}]',
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }

    all_wavs = []
    start_time = time.time()
    if not cur_tqdm:
        cur_tqdm = tqdm
    print(f'generate_audio_for_seed,2, speaker_type:{speaker_type}')

    # Text that already carries hand-placed control tokens bypasses the refine
    # stage (presumably so the tokens are kept as written).
    if re.search(r'\[uv_break\]|\[laugh\]', ''.join(texts)) is not None:
        if not skip_refine_text:
            print("Detected [uv_break] or [laugh] in text, skipping refine_text")
            skip_refine_text = True

    for batch in cur_tqdm(batch_split(texts, batch_size), desc=f"Inferring audio for seed={seed}"):
        # Pass a fresh copy per batch in case chat.infer mutates its params dict (unverified).
        wavs = chat.infer(batch, params_infer_code={**params_infer_code}, params_refine_text=params_refine_text,
                          use_decoder=True, skip_refine_text=skip_refine_text)
        all_wavs.extend(wavs)
        clear_cuda_cache()
    print(f'generate_audio_for_seed,3, speaker_type:{speaker_type}')

    if skip_save:
        return all_wavs

    combined_audio = combine_audio(all_wavs)
    elapsed_time = time.time() - start_time
    print(f"Saving audio for seed {seed}, took {elapsed_time:.2f}s")
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
    wav_filename = f"chattts-[seed_{seed}][speed_{speed}]{refine_text_prompt}[{timestamp}].wav"
    return save_audio(wav_filename, combined_audio)
def generate_refine_text(chat, seed, text, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
                         top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
    """
    Run only ChatTTS's text-refinement stage on ``text`` (no audio is produced).

    :param chat: loaded ChatTTS model
    :param seed: RNG seed. None / -1 / 0 / "" / "random" draw a random int in [0, 9999).
    :param text: input text passed straight to ``chat.infer``
    :param refine_text_prompt: refine prompt, e.g. "[oral_2][laugh_0][break_6]"
    :param temperature: sampling temperature for the refine stage
    :param top_P: top-p for the refine stage
    :param top_K: top-k for the refine stage
    :return: the value returned by ``chat.infer(..., refine_text_only=True)``
    """
    if seed in [None, -1, 0, "", "random"]:
        seed = np.random.randint(0, 9999)
    deterministic(seed)
    # Log the effective seed so a randomly drawn one can be reproduced later.
    print('seed:', seed)
    params_refine_text = {
        'prompt': refine_text_prompt,
        'top_P': top_P,
        'top_K': top_K,
        'temperature': temperature
    }
    # The label was previously 'params_refine_text:' although the input text is what gets printed.
    print('text:', text)
    print('refine_text_prompt:', refine_text_prompt)
    refine_text = chat.infer(text, params_refine_text=params_refine_text, refine_text_only=True, skip_refine_text=False)
    print('refine_text:', refine_text)
    return refine_text
def tts(chat, text_file, seed, speed, oral, laugh, bk, seg, batch, progres=None):
    """
    Text-to-Speech: synthesize a text file or a literal string with ChatTTS and save it.

    :param chat: ChatTTS model
    :param text_file: path to an existing text file, or the text itself as a str
    :param seed: speaker seed forwarded to ``generate_audio_for_seed``
    :param speed: speed value inserted as ``[speed_{speed}]``
    :param oral: oral level, 0-9, inserted as ``[oral_{oral}]``
    :param laugh: laugh level, 0-2, inserted as ``[laugh_{laugh}]``
    :param bk: break level, 0-7, inserted as ``[break_{bk}]``
    :param seg: ``min_length`` passed to ``utils.split_text``
    :param batch: batch size forwarded to ``generate_audio_for_seed``
    :param progres: unused; kept for interface compatibility
    :return: the result of ``generate_audio_for_seed`` (the saved audio)
    :raises ValueError: oral/laugh/bk out of range
    :raises TypeError: ``text_file`` is neither a str nor a path-like object
    :raises FileNotFoundError: ``text_file`` is path-like but not an existing file
    """
    # Validate cheap arguments first so bad input fails before any file I/O.
    if not (0 <= oral <= 9 and 0 <= laugh <= 2 and 0 <= bk <= 7):
        raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")
    if not isinstance(text_file, (str, os.PathLike)):
        raise TypeError(f"text_file must be a str or path-like, got {type(text_file).__name__}")

    from utils import read_long_text, split_text

    if os.path.isfile(text_file):
        content = read_long_text(text_file)
    elif isinstance(text_file, str):
        # A str that is not an existing file is treated as the text itself.
        content = text_file
    else:
        # Previously this case left `content` unbound (UnboundLocalError).
        raise FileNotFoundError(f"Text file not found: {text_file}")

    texts = split_text(content, min_length=seg)
    print(texts)
    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    return generate_audio_for_seed(chat, seed, texts, batch, speed, refine_text_prompt)