# Hugging Face file-viewer header captured with this file (not part of the program):
#   MZhaovo's picture | Upload folder using huggingface_hub | 1379699 | raw | history blame | 7.06 kB
import nltk
import ssl
# Work around HTTPS certificate failures when fetching NLTK data: if this
# Python build exposes an unverified-context factory, install it as the
# default HTTPS context factory before downloading.
# NOTE(review): the replacement is never undone in this file, so certificate
# verification stays disabled for later default-context HTTPS requests made
# by this process as well - confirm this is acceptable.
if hasattr(ssl, "_create_unverified_context"):
    _create_unverified_https_context = ssl._create_unverified_context
    ssl._create_default_https_context = _create_unverified_https_context

# Fetch the CMU pronouncing dictionary corpus through NLTK.
nltk.download("cmudict")
import os
import json
import random
import gradio as gr
import numpy as np
import torch
import re_matching
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr
from config import config
from tools.webui import reload_javascript, get_character_html
# Inference device, taken from the web-UI section of the project config.
device = config.webui_config.device

# On the "mps" device, set PyTorch's MPS-fallback switch in the environment.
# NOTE(review): torch is already imported above; confirm the variable is
# still honored when set after the import.
if device == "mps":
    os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")
def speak_fn(
    text: str,
    exceed_flag,
    speaker="TalkFlower_CNzh",
    sdp_ratio=0.2,  # SDP/DP mix ratio
    noise_scale=0.6,  # emotion
    noise_scale_w=0.6,  # phoneme length
    length_scale=0.9,  # speaking speed
    language="ZH",
    interval_between_para=0.2,  # pause between paragraphs, in seconds
    interval_between_sent=1,  # pause between sentences, in seconds
):
    """Synthesize speech for ``text`` and build the values for the UI outputs.

    Behaviour depends on the length of ``text`` (after collapsing blank lines):

    * more than 100 characters: nothing is synthesized. One of two canned
      audio clips is returned; ``exceed_flag`` selects which one and is
      toggled so consecutive over-long inputs alternate between them.
    * more than 42 characters: the text is cut into paragraphs and sentences
      with ``re_matching``; each sentence is synthesized separately and
      joined with silences.
    * otherwise: the text is split on ``"|"`` and each piece is synthesized,
      followed by half a second of silence.

    Relies on the module-level ``hps``, ``net_g`` and ``device``.

    Args:
        text: Text to speak.
        exceed_flag: Whether the previous request was already over-long.
        speaker: Speaker id forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded to ``infer``.
        noise_scale: Forwarded to ``infer``.
        noise_scale_w: Forwarded to ``infer``.
        length_scale: Forwarded to ``infer``.
        language: Forwarded to ``infer``.
        interval_between_para: Total silence after a paragraph, in seconds
            (only the part exceeding ``interval_between_sent`` is added,
            because every sentence is already followed by a sentence pause).
        interval_between_sent: Silence after each sentence, in seconds.

    Returns:
        A 4-tuple ``(audio, character_html, exceed_flag, button_update)``:
        the audio value (either a ``gr.update`` pointing at a canned clip or
        a ``(sampling_rate, samples)`` pair), the character HTML for the
        displayed text, the new over-length flag, and an update that makes
        the submit button interactive again.
    """
    # Use the model's own rate everywhere; the long-text branch used to
    # hard-code 44100, which mis-sized the pauses for any other rate.
    sampling_rate = hps.data.sampling_rate
    to_16_bit = gr.processing_utils.convert_to_16_bit_wav

    def _synthesize(fragment):
        # Single place for the inference call shared by both branches; run
        # without gradient tracking (previously only the short branch did).
        with torch.no_grad():
            return infer(
                fragment,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
            )

    # Collapse runs of blank lines into single line breaks.
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    if len(text) > 100:
        print(f"Too Long Text: {text}")
        if exceed_flag:
            text = "不要超过100字!"
            return (
                gr.update(value="./assets/audios/nomorethan100.wav"),
                get_character_html(text),
                False,
                gr.update(interactive=True),
            )
        text = "这句太长了,憋坏我啦!"
        return (
            gr.update(value="./assets/audios/overlength.wav"),
            get_character_html(text),
            True,
            gr.update(interactive=True),
        )

    audio_list = []
    if len(text) > 42:
        print(f"Long Text: {text}")
        sentence_gap = np.zeros(int(sampling_rate * interval_between_sent))
        extra_para_gap = interval_between_para - interval_between_sent
        for paragraph in re_matching.cut_para(text):
            audio_list_sent = []
            for sentence in re_matching.cut_sent(paragraph):
                audio_list_sent.append(_synthesize(sentence))
                audio_list_sent.append(sentence_gap)
            if extra_para_gap > 0:
                audio_list_sent.append(np.zeros(int(sampling_rate * extra_para_gap)))
            # Convert the paragraph as a whole (the original comment notes
            # this is so volume normalization spans the complete passage).
            audio_list.append(to_16_bit(np.concatenate(audio_list_sent)))
    else:
        print(f"Short Text: {text}")
        silence = np.zeros(sampling_rate // 2, dtype=np.int16)
        for piece in text.split("|"):
            audio_list.append(to_16_bit(_synthesize(piece)))
            audio_list.append(silence)  # half a second of silence after each piece

    audio_concat = np.concatenate(audio_list)
    return (
        (sampling_rate, audio_concat),
        get_character_html(text),
        exceed_flag,
        gr.update(interactive=True),
    )
def submit_lock_fn():
    """Return a Gradio update that makes a component non-interactive."""
    lock_update = gr.update(interactive=False)
    return lock_update
def init_fn():
    """Show the startup notices and pick a random welcome line.

    Returns:
        A pair ``(audio_update, character_html)``: an update pointing at the
        welcome clip with the chosen number, and the character HTML showing
        the matching "Welcome" sentence.
    """
    notices = (
        "2023-11-24: 优化长句生成效果;增加示例;更新了一些小彩蛋;画了一些大饼)",
        "Only support Chinese now. Trying to train a mutilingual model. 欢迎在 Community 中提建议~",
    )
    for notice in notices:
        gr.Info(notice)

    # Welcome clips are numbered 1 through 7.
    pick = random.randint(1, 7)
    greeting = get_sentence("Welcome", pick)
    return (
        gr.update(value=f"./assets/audios/Welcome{pick}.wav"),
        get_character_html(greeting),
    )
def get_sentence(category, index=-1):
    """Look up a canned sentence in the module-level ``full_lines`` table.

    Args:
        category: Key of the sentence group in ``full_lines``.
        index: Number of the sentence inside the group; ``-1`` (the default)
            picks one at random between 1 and the size of the group.

    Returns:
        The entry stored under the string form of ``index``.
    """
    group = full_lines[category]
    if index == -1:
        index = random.randint(1, len(group))
    return group[f"{index}"]
# Stylesheet text handed to gr.Blocks(css=...).
with open("./css/style.css", encoding="utf-8") as f:
    customCSS = f.read()

# Table of canned lines consumed by get_sentence(); indexed by category and
# then by a number rendered as a string.
with open("./assets/lines.json", encoding="utf-8") as f:
    full_lines = json.loads(f.read())
# Page layout and event wiring for the demo.
with gr.Blocks(css=customCSS) as demo:
    # Per-session flag remembered between over-long requests (see speak_fn).
    exceed_flag = gr.State(value=False)
    # Hidden textbox that carries the clicked example category to get_sentence.
    tmp_string = gr.Textbox(value="", visible=False)
    character_area = gr.HTML(get_character_html("你好呀!"), elem_id="character_area")

    with gr.Tab("Speak", elem_id="tab-speak"):
        speak_input = gr.Textbox(
            lines=1,
            label="Talking Flower will say:",
            elem_classes="wonder-card",
            elem_id="input_text",
        )
        speak_button = gr.Button(
            "Speak!",
            elem_id="speak_button",
            elem_classes="main-button wonder-card",
        )
        # Clicking a category fills the input with a random line from it.
        example_category = gr.Examples(
            ["夸夸你 | Praise", "游戏台词 | Scripts", "玩梗 | Meme"],
            fn=get_sentence,
            inputs=[tmp_string],
            outputs=[speak_input],
            run_on_click=True,
            elem_id="examples",
        )

    # The Chat and Mimic tabs are placeholders: their inputs are disabled and
    # their buttons have no handlers attached.
    with gr.Tab("Chat", elem_id="tab-chat"):
        chat_input = gr.Textbox(
            lines=1,
            placeholder="Coming Soon...",
            label="Chat to Talking Flower:",
            elem_classes="wonder-card",
            elem_id="input_text",
            interactive=False,
        )
        chat_button = gr.Button(
            "Chat!",
            elem_id="chat_button",
            elem_classes="main-button wonder-card",
        )

    with gr.Tab("Mimic", elem_id="tab-mimic"):
        gr.Textbox(
            lines=1,
            placeholder="Coming Soon...",
            label="Choose sound to mimic:",
            elem_classes="wonder-card",
            elem_id="input_text",
            interactive=False,
        )
        mimic_button = gr.Button(
            "Mimic!",
            elem_id="mimic_button",
            elem_classes="main-button wonder-card",
        )

    audio_output = gr.Audio(
        label="输出音频",
        show_label=False,
        autoplay=True,
        elem_id="audio_output",
        elem_classes="wonder-card",
    )

    # Play a welcome line when the page loads.
    demo.load(
        init_fn,
        inputs=[],
        outputs=[audio_output, character_area],
    )

    # Pressing Enter in the textbox and clicking the button do the same thing:
    #   1. lock the Speak button (the lock step must name the button as its
    #      output, otherwise the returned update is discarded),
    #   2. run speak_fn, which re-enables the button on success,
    #   3. unlock again unconditionally, so the button is not left disabled
    #      when speak_fn raises.
    for trigger in (speak_input.submit, speak_button.click):
        trigger(
            submit_lock_fn,
            outputs=[speak_button],
            show_progress=False,
        ).then(
            speak_fn,
            inputs=[speak_input, exceed_flag],
            outputs=[audio_output, character_area, exceed_flag, speak_button],
        ).then(
            lambda: gr.update(interactive=True),
            outputs=[speak_button],
            show_progress=False,
        )
if __name__ == "__main__":
    # hps and net_g become module globals here; speak_fn reads them directly.
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Use the version recorded in the hyper-parameters when there is one,
    # otherwise fall back to the newest version known to infer.
    version = getattr(hps, "version", latest_version)
    net_g = get_net_g(
        model_path=config.webui_config.model,
        version=version,
        device=device,
        hps=hps,
    )
    reload_javascript()
    # ./assets is whitelisted because the handlers return audio files from it.
    demo.launch(allowed_paths=["./assets"], show_api=False, inbrowser=True)