# Spaces: Runtime error
import nltk
import ssl

# Presumably to work around hosts where HTTPS certificate verification fails,
# the nltk data download below runs with the stdlib's unverified SSL context.
# The swap used to stay in place for the life of the process, silently
# disabling verification for every later stdlib HTTPS request that relies on
# the default context (http.client / urllib); it is now undone as soon as the
# download finishes, whether or not it succeeded.
_default_https_context = ssl._create_default_https_context
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    # Private helper unavailable: download with the default (verifying) context.
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context
try:
    nltk.download("cmudict")
finally:
    ssl._create_default_https_context = _default_https_context
import json
import os
import random
import gradio as gr
import numpy as np
import torch
import re_matching
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr  # NOTE(review): duplicate of the gradio import above; redundant but harmless
from config import config
from tools.webui import reload_javascript, get_character_html

# Inference device, taken from the web-UI section of the project config.
device = config.webui_config.device
if device == "mps":
    # Enable PyTorch's CPU fallback for ops the MPS backend lacks (see PyTorch MPS docs).
    os.environ.update(PYTORCH_ENABLE_MPS_FALLBACK="1")
def speak_fn(
    text: str,
    exceed_flag,
    speaker="TalkFlower_CNzh",
    sdp_ratio=0.2,  # SDP/DP mixing ratio
    noise_scale=0.6,  # emotion
    noise_scale_w=0.6,  # phoneme duration
    length_scale=0.9,  # speaking rate
    language="ZH",
    interval_between_para=0.2,  # pause between paragraphs, in seconds
    interval_between_sent=1,  # pause between sentences, in seconds
):
    """Synthesize ``text`` and build the Gradio outputs for the Speak tab.

    Runs of blank lines are collapsed first. Text longer than 100 characters
    is not synthesized: one of two canned replies and its pre-recorded clip
    are returned instead, alternating on consecutive over-length requests via
    ``exceed_flag``. Text longer than 42 characters is split into paragraphs
    and sentences with ``re_matching``; each paragraph is volume-normalized as
    a whole, with silences between sentences and paragraphs. Shorter text is
    split on ``"|"`` and every piece is followed by ``sampling_rate // 2``
    samples of silence.

    Reads the module-level ``hps``, ``net_g`` and ``device`` globals (only
    when synthesis actually happens).

    Returns:
        A 4-tuple of updates for (audio output, character HTML, exceed-flag
        state, Speak button). The audio value is a clip path for over-length
        text, otherwise ``(sampling_rate, int16 samples)``; the button update
        re-enables it.
    """
    # Collapse runs of blank lines into a single newline.
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    if len(text) > 100:
        # Over the limit: skip synthesis and answer with a canned reply,
        # alternating between the two on each over-length request.
        print(f"Too Long Text: {text}")
        if exceed_flag:
            text = "不要超过100字!"
            audio_value = "./assets/audios/nomorethan100.wav"
        else:
            text = "这句太长了,憋坏我啦!"
            audio_value = "./assets/audios/overlength.wav"
        exceed_flag = not exceed_flag
    else:
        sampling_rate = hps.data.sampling_rate

        def synthesize(piece):
            # One TTS call with this request's voice and prosody settings.
            return infer(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
            )

        audio_list = []
        # Only forward passes are needed; both paths now run under no_grad
        # (previously only the short-text path did).
        with torch.no_grad():
            if len(text) > 42:
                print(f"Long Text: {text}")
                # Pause lengths use the model's own sampling rate; they were
                # hard-coded to 44100, mis-timing pauses for any other rate.
                sent_silence = np.zeros(int(sampling_rate * interval_between_sent))
                extra_para_gap = interval_between_para - interval_between_sent
                for para in re_matching.cut_para(text):
                    sentences = re_matching.cut_sent(para)
                    if not sentences:
                        # Nothing to say: concatenating would raise on an empty
                        # list or feed pure silence to the volume normalizer.
                        continue
                    para_audio = []
                    for sent in sentences:
                        para_audio.append(synthesize(sent))
                        para_audio.append(sent_silence)
                    if extra_para_gap > 0:
                        # Each sentence already ends with sent_silence; top up
                        # to the requested paragraph gap.
                        para_audio.append(np.zeros(int(sampling_rate * extra_para_gap)))
                    # Normalize volume over the complete paragraph at once.
                    audio_list.append(
                        gr.processing_utils.convert_to_16_bit_wav(np.concatenate(para_audio))
                    )
            else:
                print(f"Short Text: {text}")
                gap = np.zeros(sampling_rate // 2, dtype=np.int16)
                for piece in text.split("|"):
                    audio_list.append(gr.processing_utils.convert_to_16_bit_wav(synthesize(piece)))
                    audio_list.append(gap)
        audio_value = (sampling_rate, np.concatenate(audio_list))

    return (
        gr.update(value=audio_value, autoplay=True),
        get_character_html(text),
        exceed_flag,
        gr.update(interactive=True),
    )
def submit_lock_fn():
    """Return a Gradio update that sets ``interactive=False`` on its output.

    The event wiring chains this before ``speak_fn``, whose last output
    (``gr.update(interactive=True)``) makes the Speak button usable again.
    """
    locked = gr.update(interactive=False)
    return locked
def init_fn():
    """Greet the visitor when the page loads.

    Shows two info notifications, then picks a random welcome index and
    returns updates for (audio output, character HTML): the clip
    ``Welcome{index}.wav`` (not autoplayed) and the "Welcome" line with the
    same index, so audio and text match.
    """
    # Changelog notice (Chinese), then the language-support notice.
    gr.Info("2023-11-24: 优化长句生成效果;增加示例;更新了一些小彩蛋;画了一些大饼)")
    gr.Info("Only support Chinese now. Trying to train a multilingual model. 欢迎在 Community 中提建议~")
    # Assumes Welcome1.wav … Welcome7.wav exist and lines.json has "Welcome"
    # entries "1".."7" — TODO update if clips are added or removed.
    welcome_count = 7
    index = random.randint(1, welcome_count)
    welcome_text = get_sentence("Welcome", index)
    return (
        gr.update(value=f"./assets/audios/Welcome{index}.wav", autoplay=False),
        get_character_html(welcome_text),
    )
def get_sentence(category, index=-1):
    """Return a scripted line from the module-level ``full_lines`` mapping.

    ``full_lines[category]`` is keyed by the string form of 1-based integers.
    With the default ``index=-1`` a key is chosen uniformly from
    ``1..len(full_lines[category])``; otherwise ``index`` is used directly.
    """
    lines = full_lines[category]
    key = random.randint(1, len(lines)) if index == -1 else index
    return lines[str(key)]
# Stylesheet handed to gr.Blocks, and the scripted lines read by get_sentence().
with open("./css/style.css", "r", encoding="utf-8") as css_file:
    customCSS = css_file.read()

with open("./assets/lines.json", "r", encoding="utf-8") as lines_file:
    full_lines = json.loads(lines_file.read())
with gr.Blocks(css=customCSS) as demo:
    # Per-session flag that speak_fn flips on each over-length request, so the
    # two "too long" replies alternate.
    exceed_flag = gr.State(value=False)
    # Hidden textbox that the example entries below are written into.
    # NOTE(review): nothing in this file reads tmp_string yet.
    tmp_string = gr.Textbox(value="", visible=False)
    character_area = gr.HTML(get_character_html("你好呀!"), elem_id="character_area")

    with gr.Tab("Speak", elem_id="tab-speak"):
        speak_input = gr.Textbox(
            lines=1,
            label="Talking Flower will say:",
            elem_classes="wonder-card input_text",
            elem_id="speak_input",
        )
        speak_button = gr.Button(
            "Speak!", elem_id="speak_button", elem_classes="main-button wonder-card"
        )
        example_category = gr.Examples(
            ["夸夸你 | Praise", "游戏台词 | Scripts", "玩梗 | Meme"],
            inputs=[tmp_string],
            elem_id="examples",
        )

    # Chat and Mimic are placeholders: their inputs are non-interactive and
    # their buttons are not wired to any handler here.
    with gr.Tab("Chat", elem_id="tab-chat"):
        chat_input = gr.Textbox(
            lines=1,
            placeholder="Coming Soon...",
            label="Chat to Talking Flower:",
            elem_classes="wonder-card input_text",
            elem_id="chat_input",
            interactive=False,
        )
        chat_button = gr.Button(
            "Chat!", elem_id="chat_button", elem_classes="main-button wonder-card"
        )
    with gr.Tab("Mimic", elem_id="tab-mimic"):
        gr.Textbox(
            lines=1,
            placeholder="Coming Soon...",
            label="Choose sound to mimic:",
            elem_classes="wonder-card input_text",
            elem_id="mimic_input",
            interactive=False,
        )
        mimic_button = gr.Button(
            "Mimic!", elem_id="mimic_button", elem_classes="main-button wonder-card"
        )

    audio_output = gr.Audio(
        label="输出音频",
        show_label=False,
        autoplay=True,
        elem_id="audio_output",
        elem_classes="wonder-card",
    )

    demo.load(init_fn, inputs=[], outputs=[audio_output, character_area])

    # Enter in the textbox and a click on "Speak!" run the same pipeline. The
    # lock step must list speak_button as its output; without it the
    # interactive=False update from submit_lock_fn is discarded and the button
    # is never disabled. speak_fn's last output re-enables the button.
    for speak_trigger in (speak_input.submit, speak_button.click):
        speak_trigger(submit_lock_fn, outputs=[speak_button], show_progress=False).then(
            speak_fn,
            inputs=[speak_input, exceed_flag],
            outputs=[audio_output, character_area, exceed_flag, speak_button],
        )
if __name__ == "__main__":
    # speak_fn reads `hps` and `net_g` as module globals. They are only created
    # here, so they exist only when this file is run as a script.
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Configs without a `version` attribute fall back to infer.latest_version.
    version = getattr(hps, "version", latest_version)
    net_g = get_net_g(
        model_path=config.webui_config.model,
        version=version,
        device=device,
        hps=hps,
    )
    reload_javascript()
    demo.launch(
        allowed_paths=["./assets", "./javascript", "./css"],
        show_api=False,
        inbrowser=True,
    )