# Last commit by: txya900619
# Commit message: feat: change markdown and lable, and prepare to handle diff sample rate model and one speaker one language model
# Commit: b128fb7
# (file-viewer links: raw)
# (file-viewer links: history, blame)
# File size: 6.76 kB
import json
import os
import tempfile
import gradio as gr
import TTS
from TTS.utils.synthesizer import Synthesizer
import numpy as np
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig
# Rebind the ``VitsConfig`` name on TTS's vits_config module to the project's
# ChangedVitsConfig. Presumably this is so that code looking the class up on
# that module at config-load time gets the customised config class instead of
# the stock one. This must run before any model is loaded.
setattr(TTS.tts.configs.vits_config, "VitsConfig", ChangedVitsConfig)
def load_model(model_id):
model_dir = snapshot_download(model_id)
config_file_path = os.path.join(model_dir, "config.json")
model_ckpt_path = os.path.join(model_dir, "model.pth")
speaker_file_path = os.path.join(model_dir, "speakers.pth")
language_file_path = os.path.join(model_dir, "language_ids.json")
speaker_embedding_file_path = os.path.join(model_dir, "speaker_embs.pth")
temp_config_path = "temp_config.json"
with open(config_file_path, "r") as f:
content = f.read()
content = content.replace("speakers.pth", speaker_file_path)
content = content.replace("language_ids.json", language_file_path)
content = content.replace("speaker_embs.pth", speaker_embedding_file_path)
f.close()
with open(temp_config_path, "w") as f:
f.write(content)
f.close()
return Synthesizer(tts_checkpoint=model_ckpt_path, tts_config_path=temp_config_path)
# Make `load_model` available to configs/models.yaml as an OmegaConf resolver,
# presumably referenced there as `${load_model:<hub repo id>}`. Converting the
# YAML to plain Python objects resolves every interpolation. That means all
# models are downloaded and loaded here, at import time. Each entry's
# "model" value then holds the resulting Synthesizer.
OmegaConf.register_new_resolver(name="load_model", resolver=load_model)
models_config = OmegaConf.to_object(
    OmegaConf.load("configs/models.yaml"),
)
def text_to_speech(
model_id: str,
use_default_emb_or_custom: str,
speaker_wav,
speaker: str,
dialect,
text: str,
):
model = models_config[model_id]["model"]
if len(text) == 0:
raise gr.Error("請勿輸入空字串。")
words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
if len(missing_words) > 0:
raise gr.Error(
f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
)
if use_default_emb_or_custom == "default":
wav = model.tts(
parse_ipa(ipa),
speaker_name=speaker,
language_name=dialect,
split_sentences=False,
)
else:
wav = model.tts(
parse_ipa(ipa),
speaker_wav=speaker_wav,
language_name=dialect,
split_sentences=False,
)
return (
words,
pinyin,
(model.tts_model.config.audio.sample_rate, np.array(wav)),
)
def when_model_selected(model_id):
model_config = models_config[model_id]
speaker_drop_down_choices = []
if "speaker_mapping" in model_config:
speaker_drop_down_choices = [
(k, v) for k, v in model_config["speaker_mapping"].items()
]
dialect_drop_down_choices = [
(k, v) for k, v in model_config["dialect_mapping"].items()
]
use_default_emb_or_ref_radio_visible = False
if model_config["model"].tts_model.config.model_args.speaker_encoder_model_path:
use_default_emb_or_ref_radio_visible = True
return (
gr.update(
choices=speaker_drop_down_choices,
value=speaker_drop_down_choices[0][1] if len(speaker_drop_down_choices) > 0 else None,
visible=len(speaker_drop_down_choices) > 1,
),
gr.update(
choices=dialect_drop_down_choices,
value=dialect_drop_down_choices[0][1],
visible=len(dialect_drop_down_choices) > 1,
),
gr.update(visible=use_default_emb_or_ref_radio_visible, value="default"),
)
def use_default_emb_or_custom_radio_input(use_default_emb_or_custom):
    """Switch between the microphone and the speaker dropdown.

    Args:
        use_default_emb_or_custom: Current radio value.

    Returns:
        ``(microphone_update, speaker_dropdown_update)``. The microphone is
        shown only for ``"custom"``; the speaker dropdown is shown otherwise.
    """
    wants_custom = use_default_emb_or_custom == "custom"
    return gr.update(visible=wants_custom), gr.update(visible=not wants_custom)
# Page shell: the browser-tab title, plus the tauhu-oo web font loaded from
# tauhu.tw and placed first in the font stack ahead of generic fallbacks.
demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
)

with demo:
    # The first entry of configs/models.yaml is the preselected model.
    default_model_id = list(models_config.keys())[0]
    default_model_config = models_config[default_model_id]
    # speaker_mapping is optional (when_model_selected treats it that way), and
    # a blank YAML key loads as None. Fall back to an empty mapping for both
    # dropdowns rather than crashing at startup.
    default_speaker_mapping = default_model_config.get("speaker_mapping") or {}
    default_dialect_mapping = default_model_config.get("dialect_mapping") or {}

    model_drop_down = gr.Dropdown(
        list(models_config.keys()),
        value=default_model_id,
        label="模型",
    )
    # Hidden until a handler shows it. when_model_selected makes it visible
    # for models whose config sets speaker_encoder_model_path.
    use_default_emb_or_custom_radio = gr.Radio(
        label="use default speaker embedding or custom speaker embedding",
        choices=["default", "custom"],
        value="default",
        visible=False,
    )
    # Reference recording for "custom" mode, handed to text_to_speech as a
    # file path.
    speaker_wav = gr.Microphone(
        label="speaker wav",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=16000,
        ),
    )
    speaker_drop_down = gr.Dropdown(
        choices=list(default_speaker_mapping.items()),
        value=next(iter(default_speaker_mapping.values()), None),
        label="語者",
    )
    use_default_emb_or_custom_radio.input(
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    dialect_drop_down = gr.Dropdown(
        choices=list(default_dialect_mapping.items()),
        value=next(iter(default_dialect_mapping.values()), None),
        label="腔調",
    )

    # Components whose choices/visibility depend on the selected model, in the
    # order when_model_selected returns its updates.
    model_dependent_outputs = [
        speaker_drop_down,
        dialect_drop_down,
        use_default_emb_or_custom_radio,
    ]
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=model_dependent_outputs,
    ).then(
        # when_model_selected resets the radio to "default", so hide the
        # microphone too. Otherwise it stays visible after a switch made
        # while "custom" was selected.
        lambda: gr.update(visible=False),
        outputs=[speaker_wav],
    )
    # Apply the same per-model rules on page load, so the default model's
    # initial visibility matches what selecting it would produce.
    demo.load(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=model_dependent_outputs,
    )

    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 模型
- **sixian-1p-240417**(四縣腔,單一語者)
### 研發
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)(諾思資訊 North Co., Ltd.)**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)(諾思資訊 North Co., Ltd.)**
"""
    )
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_emb_or_custom_radio,
            speaker_wav,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(label="輸入文字"),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )

if __name__ == "__main__":
    demo.launch()