# AI_TalkingFlower / presets.py
# Last commit 657d73c by MZhao-LEGION: "support change Voices"
import os, logging, datetime, json, random
import gradio as gr
import numpy as np
import torch
import re_matching
import utils
from infer import infer, latest_version, get_net_g, infer_multilang
import gradio as gr
from config import config
from tools.webui import reload_javascript, get_character_html
from tools.sentence import split_by_language
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s|%(asctime)s]%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Inference device is taken from the web UI section of the project config.
device = config.webui_config.device
if device == "mps":
    # Let PyTorch run ops that have no MPS kernel on the CPU instead of failing.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Hyper-parameters and the synthesizer are loaded once, at import time.
# Configs without an explicit "version" attribute fall back to latest_version.
hps = utils.get_hparams_from_file(config.webui_config.config_path)
version = getattr(hps, "version", latest_version)
net_g = get_net_g(
    model_path=config.webui_config.model,
    version=version,
    device=device,
    hps=hps,
)

# Static UI assets (paths are relative to the working directory):
# the custom stylesheet and the canned character lines.
with open("./css/style.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
with open("./assets/lines.json", "r", encoding="utf-8") as f:
    full_lines = json.load(f)
# Language labels shown in the UI, mapped to model speaker ids.
_SPEAKER_BY_LABEL = {
    "Chinese": "TalkFlower_CNzh",
    "English": "TalkFlower_USen",
    "Japanese": "TalkFlower_JPja",
}
_DEFAULT_SPEAKER = "TalkFlower_CNzh"


def _resolve_speaker(speaker):
    """Return the model speaker id for a UI label or a raw speaker id.

    UI labels ("Chinese", "English", "Japanese") are mapped to their ids, a
    known speaker id is passed through unchanged, and anything else falls
    back to the Chinese voice.
    """
    if speaker in _SPEAKER_BY_LABEL:
        return _SPEAKER_BY_LABEL[speaker]
    if speaker in _SPEAKER_BY_LABEL.values():
        return speaker
    return _DEFAULT_SPEAKER


def _to_model_lang(lang):
    """Upper-case a split_by_language code for infer_multilang ("ja" -> "JP")."""
    lang = lang.upper()
    return "JP" if lang == "JA" else lang


def speak_fn(
    text: str,
    exceed_flag,
    speaker="TalkFlower_CNzh",
    sdp_ratio=0.2,  # SDP/DP mixing ratio
    noise_scale=0.6,  # emotion / expressiveness
    noise_scale_w=0.6,  # phoneme duration
    length_scale=0.9,  # speaking rate
    language="ZH",
    reference_audio=None,
    emotion=4,
    interval_between_para=0.2,  # interval between paragraphs
    interval_between_sent=1,  # interval between sentences
):
    """Synthesize *text* with the TalkFlower voice and build the UI updates.

    Runs of blank lines are collapsed first. Text longer than 100 characters
    is not synthesized; a canned reply and pre-recorded audio are returned
    instead, alternating on every rejection via ``exceed_flag``. Otherwise
    the text is split on "|", each non-empty segment is split by language
    (zh/ja/en) and synthesized as one mixed-language piece with
    ``infer_multilang``, and the 16-bit clips are concatenated.

    Args:
        text: text to speak; "|" separates independently synthesized segments.
        exceed_flag: toggled each time an over-long text is rejected; selects
            which canned reply is used.
        speaker: a UI label ("Chinese"/"English"/"Japanese") or a speaker id
            such as "TalkFlower_CNzh"; unknown values use the Chinese voice.
        sdp_ratio, noise_scale, noise_scale_w, length_scale, reference_audio,
        emotion: forwarded unchanged to ``infer_multilang``.
        language, interval_between_para, interval_between_sent: currently
            unused; kept for interface compatibility.

    Returns:
        A 4-tuple: audio component update (file path or
        ``(sampling_rate, int16 array)``, autoplay on), the character HTML
        for the (possibly replaced) text, the new ``exceed_flag``, and
        ``gr.update(interactive=True)``.
    """
    speaker = _resolve_speaker(speaker)

    # Collapse runs of blank lines into single newlines.
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    if len(text) > 100:
        logging.info(f"Too Long Text: {text}")
        if exceed_flag:
            text = "不要超过100字!"
            audio_value = "./assets/audios/nomorethan100.wav"
        else:
            text = "这句太长了,憋坏我啦!"
            audio_value = "./assets/audios/overlength.wav"
        exceed_flag = not exceed_flag
    else:
        audio_list = []
        for segment in text.split("|"):
            if not segment:
                continue
            sentences = split_by_language(
                segment, target_languages=["zh", "ja", "en"]
            )
            if not sentences:
                continue
            # Every language run of the segment goes into a single
            # mixed-language piece: parallel lists of texts and language codes.
            contents = [content for content, _ in sentences]
            langs = [_to_model_lang(lang) for _, lang in sentences]
            logging.info(f"{speaker[-4:]}: {[contents]}{[langs]}")
            # Each segment is one piece, so no start/end trimming is applied.
            # NOTE(review): trimming between "|" segments is not enabled here;
            # revisit if segment joins sound padded.
            with torch.no_grad():
                audio = infer_multilang(
                    contents,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=langs,
                    hps=hps,
                    net_g=net_g,
                    device=device,
                    skip_start=False,
                    skip_end=False,
                )
                audio_list.append(gr.processing_utils.convert_to_16_bit_wav(audio))
        if not audio_list:
            # Empty input (or only "|" separators): np.concatenate([]) would
            # raise before the interactive=True update below is returned,
            # so answer with half a second of silence instead.
            logging.info(f"Empty Text: {text!r}")
            audio_list = [np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)]
        audio_value = (hps.data.sampling_rate, np.concatenate(audio_list))

    return gr.update(value=audio_value, autoplay=True), get_character_html(text), exceed_flag, gr.update(interactive=True)
def submit_lock_fn():
    """Return a Gradio update that sets ``interactive=False``.

    NOTE(review): presumably bound to the submit control so it stays locked
    until speak_fn returns its ``gr.update(interactive=True)``; confirm
    against the UI wiring.
    """
    locked = gr.update(interactive=False)
    return locked
def init_fn():
    """Show the changelog notice and render the character with a welcome line.

    Returns:
        HTML from get_character_html for a randomly chosen "Welcome" line
        (line numbers 1-7) of the canned lines.
    """
    gr.Info("2023-11-28: 支持多语言(中、英、日)!支持更换音色!")
    greeting_number = random.randint(1, 7)
    greeting = get_sentence("Welcome", greeting_number)
    return get_character_html(greeting)
def get_sentence(category, index=-1):
    """Look up one canned line from ``full_lines`` (loaded from assets/lines.json).

    Args:
        category: top-level key of the lines table (e.g. "Welcome").
        index: 1-based line number; the default -1 picks one uniformly at
            random from 1..len(full_lines[category]).

    Returns:
        The entry stored under the string form of the chosen index.
    """
    entries = full_lines[category]
    chosen = random.randint(1, len(entries)) if index == -1 else index
    return entries[f"{chosen}"]