feat: change markdown and label, and prepare to handle models with different sample rates and single-speaker, single-language models
b128fb7
import json | |
import os | |
import tempfile | |
import gradio as gr | |
import TTS | |
from TTS.utils.synthesizer import Synthesizer | |
import numpy as np | |
from huggingface_hub import snapshot_download | |
from omegaconf import OmegaConf | |
from ipa.ipa import get_ipa, parse_ipa | |
from replace.tts import ChangedVitsConfig | |
# Replace Coqui TTS's stock VitsConfig with the project's ChangedVitsConfig on
# the vits_config module. This runs at import time, before models_config is
# built and any model is loaded.
# NOTE(review): this only takes effect if Coqui's config loading looks the
# class up on TTS.tts.configs.vits_config at call time — confirm against the
# installed TTS version.
setattr(TTS.tts.configs.vits_config, "VitsConfig", ChangedVitsConfig)
def load_model(model_id):
    """Download a TTS model from the Hugging Face Hub and wrap it in a Synthesizer.

    The model's ``config.json`` refers to its auxiliary files
    (``speakers.pth``, ``language_ids.json``, ``speaker_embs.pth``) by name.
    Every occurrence of those names is rewritten to the absolute path of the
    file inside the downloaded snapshot. The rewritten config goes to a fresh
    temporary file, which is handed to ``Synthesizer`` as the TTS config.
    This function is used as the OmegaConf ``load_model`` resolver.

    Args:
        model_id: Hugging Face Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``Synthesizer`` built from the snapshot's ``model.pth`` checkpoint
        and the rewritten config.
    """
    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")

    # Read as UTF-8 explicitly: the config may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(config_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Point each auxiliary file name at its copy in the snapshot (same
    # substring replacement, same order as before). The absolute path is
    # JSON-escaped via json.dumps(...)[1:-1] because it lands inside a JSON
    # string literal. A raw Windows path such as "C:\Users\..." would
    # otherwise yield an invalid escape sequence and an unparseable config.
    for file_name in ("speakers.pth", "language_ids.json", "speaker_embs.pth"):
        absolute_path = os.path.join(model_dir, file_name)
        content = content.replace(file_name, json.dumps(absolute_path)[1:-1])

    # Each model gets its own uniquely named temp file instead of a shared
    # "temp_config.json" in the working directory. The file is kept on disk
    # (delete=False) so the path handed to Synthesizer stays valid.
    with tempfile.NamedTemporaryFile(
        "w",
        encoding="utf-8",
        prefix="tts_config_",
        suffix=".json",
        delete=False,
    ) as f:
        f.write(content)
        temp_config_path = f.name

    return Synthesizer(tts_checkpoint=model_ckpt_path, tts_config_path=temp_config_path)
# Make load_model available to OmegaConf as the "load_model" resolver.
# Presumably configs/models.yaml uses ${load_model:<repo_id>} interpolations for
# each entry's "model" key — the YAML is not visible here, verify.
OmegaConf.register_new_resolver("load_model", load_model)
# Parse the registry, then convert it to plain Python containers. to_object()
# resolves every interpolation, so each referenced model is loaded here, at
# import time.
models_config = OmegaConf.load("configs/models.yaml")
models_config = OmegaConf.to_object(models_config)
def text_to_speech(
    model_id: str,
    use_default_emb_or_custom: str,
    speaker_wav,
    speaker: str,
    dialect,
    text: str,
):
    """Convert *text* to IPA and synthesize it with the selected model.

    Args:
        model_id: Key into ``models_config`` selecting the model entry.
        use_default_emb_or_custom: ``"default"`` synthesizes with the speaker
            chosen in the dropdown; any other value conditions on the
            recorded ``speaker_wav`` instead.
        speaker_wav: File path of the recorded reference audio (the microphone
            component uses ``type="filepath"``). Only used in custom mode.
        speaker: Speaker name passed to the synthesizer in default mode.
        dialect: Dialect used for IPA conversion and passed to the synthesizer
            as ``language_name``.
        text: Input sentence.

    Returns:
        ``(words, pinyin, (sample_rate, waveform))``: the segmented words and
        the romanization from ``get_ipa``, plus the audio tuple for the
        ``gr.Audio`` output.

    Raises:
        gr.Error: if the text is empty or blank, if some words cannot be
            converted to IPA, or if custom mode is chosen without a recording.
    """
    model = models_config[model_id]["model"]

    # Treat whitespace-only input as empty too; there is nothing to synthesize.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    # The two modes differ only in how the speaker is specified.
    if use_default_emb_or_custom == "default":
        speaker_kwargs = {"speaker_name": speaker}
    else:
        # Without a recording the synthesizer would get speaker_wav=None;
        # report it to the user instead of failing inside the model.
        if not speaker_wav:
            raise gr.Error("請先錄製參考語音。")
        speaker_kwargs = {"speaker_wav": speaker_wav}

    wav = model.tts(
        parse_ipa(ipa),
        language_name=dialect,
        split_sentences=False,
        **speaker_kwargs,
    )
    return (
        words,
        pinyin,
        (model.tts_model.config.audio.sample_rate, np.array(wav)),
    )
def when_model_selected(model_id):
    """Build widget updates for a newly selected model.

    Args:
        model_id: Key into ``models_config``.

    Returns:
        Three ``gr.update`` objects, for the speaker dropdown, the dialect
        dropdown and the default/custom embedding radio, in that order.
        Each dropdown gets the model's ``(label, value)`` pairs, is preset to
        the first value, and is shown only when there is more than one choice.
        A model without ``speaker_mapping`` gets an empty speaker list. The
        radio is reset to ``"default"`` and shown only if the model's
        ``model_args.speaker_encoder_model_path`` is set.
    """
    selected = models_config[model_id]

    speakers = (
        list(selected["speaker_mapping"].items())
        if "speaker_mapping" in selected
        else []
    )
    dialects = list(selected["dialect_mapping"].items())
    has_speaker_encoder = bool(
        selected["model"].tts_model.config.model_args.speaker_encoder_model_path
    )

    speaker_update = gr.update(
        choices=speakers,
        value=speakers[0][1] if speakers else None,
        visible=len(speakers) > 1,
    )
    dialect_update = gr.update(
        choices=dialects,
        value=dialects[0][1],
        visible=len(dialects) > 1,
    )
    radio_update = gr.update(visible=has_speaker_encoder, value="default")
    return speaker_update, dialect_update, radio_update
def use_default_emb_or_custom_radio_input(use_default_emb_or_custom):
    """Toggle between the microphone recorder and the speaker dropdown.

    Returns ``(recorder_update, speaker_dropdown_update)``. In ``"custom"``
    mode the recorder is shown and the dropdown hidden; any other value does
    the opposite.
    """
    wants_recording = use_default_emb_or_custom == "custom"
    recorder_update = gr.update(visible=wants_recording)
    speaker_update = gr.update(visible=not wants_recording)
    return recorder_update, speaker_update
# Top-level Blocks container: browser-tab title, an external stylesheet
# (presumably defining the "tauhu-oo" font family listed first below), and the
# default theme with a font stack that tries "tauhu-oo" first.
demo = gr.Blocks(
    # Use the same product name as the page heading ("語音合成", i.e. speech
    # synthesis) instead of the inconsistent "語音生成".
    title="臺灣客語語音合成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
)
# Build the UI inside the Blocks container; Gradio lays components out in the
# order they are created.
with demo:
    # The first entry of models_config is the model selected on page load.
    default_model_id = list(models_config.keys())[0]
    # when_model_selected treats "speaker_mapping" as optional. Do the same
    # here so a model without it does not crash startup with a KeyError.
    default_speaker_mapping = models_config[default_model_id].get("speaker_mapping", {})

    model_drop_down = gr.Dropdown(
        list(models_config.keys()),
        value=default_model_id,
        label="模型",
    )
    # "default" synthesizes with the dropdown speaker, "custom" with the
    # recorded clip (see text_to_speech). Hidden until a model with a speaker
    # encoder is selected.
    use_default_emb_or_custom_radio = gr.Radio(
        label="use default speaker embedding or custom speaker embedding",
        choices=["default", "custom"],
        value="default",
        visible=False,
    )
    speaker_wav = gr.Microphone(
        label="speaker wav",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=16000,
        ),
    )
    speaker_drop_down = gr.Dropdown(
        choices=list(default_speaker_mapping.items()),
        value=next(iter(default_speaker_mapping.values()), None),
        label="語者",
    )
    use_default_emb_or_custom_radio.input(
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    dialect_drop_down = gr.Dropdown(
        choices=list(models_config[default_model_id]["dialect_mapping"].items()),
        value=list(models_config[default_model_id]["dialect_mapping"].values())[0],
        label="腔調",
    )
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down, use_default_emb_or_custom_radio],
    ).then(
        # when_model_selected resets the radio to "default" but cannot reach
        # the recorder. Hide it so the UI agrees with the radio.
        lambda: gr.update(visible=False),
        outputs=[speaker_wav],
    )
    # Apply the same per-model visibility rules on page load. Otherwise the
    # radio stays hidden for a default model that has a speaker encoder, and a
    # single-speaker dropdown is shown, until the user re-selects a model.
    demo.load(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down, use_default_emb_or_custom_radio],
    )
    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 模型
- **sixian-1p-240417**(四縣腔,單一語者)
### 研發
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)(諾思資訊 North Co., Ltd.)**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)(諾思資訊 North Co., Ltd.)**
"""
    )
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_emb_or_custom_radio,
            speaker_wav,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(label="輸入文字"),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        # NOTE(review): per Gradio's docs, "auto" flags every submission, so
        # user inputs (including any recorded reference audio) are persisted
        # by the flagging callback. Confirm this collection is intended and
        # disclosed.
        allow_flagging="auto",
    )

demo.launch()