# sbv2_chupa_demo / app.py
# Last commit: 51d6dd0 "Update app.py" by litagin (verified)
import datetime
from pathlib import Path
import gradio as gr
import random
import spaces
from style_bert_vits2.constants import (
DEFAULT_LENGTH,
DEFAULT_LINE_SPLIT,
DEFAULT_NOISE,
DEFAULT_NOISEW,
DEFAULT_SPLIT_INTERVAL,
)
from style_bert_vits2.logging import logger
from style_bert_vits2.models.infer import InvalidToneError
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
from style_bert_vits2.tts_model import TTSModelHolder
# Start the pyopenjtalk worker (from style_bert_vits2.nlp.japanese) once, at
# import time.
pyopenjtalk.initialize_worker()

# Text file with one example sentence per line; feeds the "例から選ぶ 🎲" button.
example_file = "chupa_examples.txt"
# Default content of the text box.
initial_text = (
    "ちゅぱ、ちゅるる、ぢゅ、んく、れーれゅれろれろれろ、じゅぽぽぽぽぽ……ちゅううう!"
)
# Drop blank / whitespace-only lines so the random picker never puts an empty
# string into the text box.
examples = [
    line
    for line in Path(example_file).read_text(encoding="utf-8").splitlines()
    if line.strip()
]
def get_random_text() -> str:
    """Return a randomly chosen example sentence for the text box.

    Falls back to ``initial_text`` when no examples were loaded, because
    ``random.choice`` raises ``IndexError`` on an empty sequence.
    """
    if not examples:
        return initial_text
    return random.choice(examples)
# Markdown shown at the top of the demo page: title, NSFW warning and a short
# changelog (newest entry first). Leading and trailing empty entries produce
# the surrounding newlines.
initial_md = "\n".join(
    [
        "",
        "# チュパ音合成デモ",
        "*NSFW注意*: このデモは、性的な音声を生成するためのものです。",
        "- 2024-07-11: chupa_3を追加",
        "- 2024-07-10: chupa_2を追加",
        "- 2024-07-07: initial ver",
        "",
    ]
)
def make_interactive():
    """Return a Gradio update that enables the synthesis button and resets its label."""
    ready_label = "音声合成"
    return gr.update(value=ready_label, interactive=True)
def make_non_interactive():
    """Return a Gradio update that disables the synthesis button.

    The label tells the user to (re)load a model first.
    """
    waiting_label = "音声合成(モデルをロードしてください)"
    return gr.update(value=waiting_label, interactive=False)
def gr_util(item):
    """Return visibility updates for a pair of output components.

    When ``item`` is "プリセットから選ぶ" ("choose from presets"), the first
    component is shown and the second is hidden and cleared. The second is
    presumably an audio component, since the previous implementation built a
    ``gr.Audio`` update for it. For any other ``item``, the first component is
    hidden and the second is shown.

    Both branches use ``gr.update`` so the return shape is consistent and does
    not depend on returning component instances from an event handler.
    """
    if item == "プリセットから選ぶ":
        # value=None clears the component that is being hidden.
        return gr.update(visible=True), gr.update(visible=False, value=None)
    return gr.update(visible=False), gr.update(visible=True)
def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks:
    """Build the Gradio app for the chupa-sound synthesis demo.

    Args:
        model_holder: Supplies the model list (``model_names``), the files of
            each model (``model_files_dict``), model loading (``get_model``,
            ``current_model``) and the ``*_for_gradio`` UI callbacks.

    Returns:
        The demo as a ``gr.Blocks``. If ``model_holder`` has no models, a
        minimal app that only shows an error message is returned instead.
    """

    @spaces.GPU
    def tts_fn(
        model_name,
        model_path,
        text,
        language,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        line_split,
        split_interval,
        speaker,
    ):
        """Synthesize ``text`` with the selected model file and speaker.

        The arguments are the values of the components listed in
        ``tts_button.click`` below, in the same order.

        Returns:
            ``(message, audio)``: a status string for the "情報" textbox and
            ``(sampling_rate, audio)`` for the audio player, or ``None`` in
            place of the audio when synthesis was aborted with an error.
        """
        model_holder.get_model(model_name, model_path)
        model = model_holder.current_model
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if model is None:
            logger.error("No model is loaded")
            return "Error: モデルのロードに失敗しました。", None
        # The speaker dropdown is only refreshed by the load button, so the
        # submitted speaker may not belong to the model selected here. Report
        # that instead of letting a KeyError escape.
        if speaker not in model.spk2id:
            logger.error(f"Unknown speaker: {speaker}")
            return (
                f"Error: 話者「{speaker}」はこのモデルに存在しません。"
                "モデルをロードし直してください。",
                None,
            )
        speaker_id = model.spk2id[speaker]
        start_time = datetime.datetime.now()
        try:
            sr, audio = model.infer(
                text=text,
                language=language,
                sdp_ratio=sdp_ratio,
                noise=noise_scale,
                noise_w=noise_scale_w,
                length=length_scale,
                line_split=line_split,
                split_interval=split_interval,
                speaker_id=speaker_id,
            )
        except InvalidToneError as e:
            logger.error(f"Tone error: {e}")
            return f"Error: アクセント指定が不正です:\n{e}", None
        except ValueError as e:
            logger.error(f"Value error: {e}")
            return f"Error: {e}", None
        end_time = datetime.datetime.now()
        duration = (end_time - start_time).total_seconds()
        message = f"Success, time: {duration} seconds."
        return message, (sr, audio)

    def get_model_files(model_name: str):
        """Return the file paths registered for ``model_name`` as strings."""
        return [str(f) for f in model_holder.model_files_dict[model_name]]

    model_names = model_holder.model_names
    if not model_names:
        logger.error(
            f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
        )
        with gr.Blocks() as app:
            gr.Markdown(
                f"Error: モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
            )
        return app

    # Load the first file of the first model up front so the speaker dropdown
    # can be populated with its speakers.
    initial_pth_files = get_model_files(model_names[0])
    model = model_holder.get_model(model_names[0], initial_pth_files[0])
    speakers = list(model.spk2id.keys())

    with gr.Blocks(theme="ParityError/Anime") as app:
        gr.Markdown(initial_md)
        with gr.Row():
            # Left column: model selection, text and synthesis parameters.
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        model_name = gr.Dropdown(
                            label="モデル一覧",
                            choices=model_names,
                            value=model_names[0],
                        )
                        model_path = gr.Dropdown(
                            label="モデルファイル",
                            choices=initial_pth_files,
                            value=initial_pth_files[0],
                        )
                    refresh_button = gr.Button("更新", scale=1, visible=False)
                    load_button = gr.Button("ロード", scale=1, variant="primary")
                with gr.Row():
                    text_input = gr.TextArea(
                        label="テキスト", value=initial_text, scale=3
                    )
                    random_button = gr.Button("例から選ぶ 🎲", scale=1)
                    random_button.click(get_random_text, outputs=[text_input])
                with gr.Row():
                    length_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_LENGTH,
                        step=0.1,
                        label="生成音声の長さ(Length)",
                    )
                    sdp_ratio = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=1,
                        step=0.1,
                        label="SDP Ratio",
                    )
                # Hidden checkbox; its value still feeds tts_fn and controls
                # whether the silence-interval slider is shown.
                line_split = gr.Checkbox(
                    label="改行で分けて生成(分けたほうが感情が乗ります)",
                    value=DEFAULT_LINE_SPLIT,
                    visible=False,
                )
                split_interval = gr.Slider(
                    minimum=0.0,
                    maximum=2,
                    value=DEFAULT_SPLIT_INTERVAL,
                    step=0.1,
                    label="改行ごとに挟む無音の長さ(秒)",
                )
                line_split.change(
                    lambda x: (gr.Slider(visible=x)),
                    inputs=[line_split],
                    outputs=[split_interval],
                )
                # Only Japanese is offered, so the selector stays hidden.
                language = gr.Dropdown(
                    choices=["JP"], value="JP", label="Language", visible=False
                )
                speaker = gr.Dropdown(label="話者", choices=speakers, value=speakers[0])
                with gr.Accordion(label="詳細設定", open=True):
                    noise_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISE,
                        step=0.1,
                        label="Noise",
                    )
                    noise_scale_w = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISEW,
                        step=0.1,
                        label="Noise_W",
                    )
            # Right column: synthesis button and outputs.
            with gr.Column():
                tts_button = gr.Button("音声合成", variant="primary")
                text_output = gr.Textbox(label="情報")
                audio_output = gr.Audio(label="結果")

        tts_button.click(
            tts_fn,
            inputs=[
                model_name,
                model_path,
                text_input,
                language,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                line_split,
                split_interval,
                speaker,
            ],
            outputs=[text_output, audio_output],
        )
        # Selecting another model refreshes the list of its files.
        model_name.change(
            model_holder.update_model_files_for_gradio,
            inputs=[model_name],
            outputs=[model_path],
        )
        # A new model file disables synthesis; the load button's callback
        # presumably re-enables it (it targets tts_button below).
        model_path.change(make_non_interactive, outputs=[tts_button])
        refresh_button.click(
            model_holder.update_model_names_for_gradio,
            outputs=[model_name, model_path, tts_button],
        )
        # Hidden placeholder so get_model_for_gradio has a component for its
        # style output; this demo does not expose style selection.
        style = gr.Dropdown(label="スタイル", choices=[], visible=False)
        load_button.click(
            model_holder.get_model_for_gradio,
            inputs=[model_name, model_path],
            outputs=[style, tts_button, speaker],
        )

    return app
if __name__ == "__main__":
    # Imports needed only when this file is run as a script.
    import torch
    from style_bert_vits2.constants import Languages
    from style_bert_vits2.nlp import bert_models

    # Load the Japanese BERT model, then its tokenizer, before building the app.
    for loader in (bert_models.load_model, bert_models.load_tokenizer):
        loader(Languages.JP)

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model_holder = TTSModelHolder(Path("model_assets"), device)
    app = create_inference_app(model_holder)
    app.launch(inbrowser=True)