# Spaces: Running on Zero
import datetime | |
from pathlib import Path | |
import gradio as gr | |
import random | |
import spaces | |
from style_bert_vits2.constants import ( | |
DEFAULT_LENGTH, | |
DEFAULT_LINE_SPLIT, | |
DEFAULT_NOISE, | |
DEFAULT_NOISEW, | |
DEFAULT_SPLIT_INTERVAL, | |
) | |
from style_bert_vits2.logging import logger | |
from style_bert_vits2.models.infer import InvalidToneError | |
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk | |
from style_bert_vits2.tts_model import TTSModelHolder | |
# Initialize the pyopenjtalk worker (style_bert_vits2's pyopenjtalk_worker
# module, imported above under the name ``pyopenjtalk``).
pyopenjtalk.initialize_worker()

# File holding the example texts, one per line.
example_file = "chupa_examples.txt"

# Text pre-filled in the input box when the app starts.
initial_text = "ちゅぱ、ちゅるる、ぢゅ、んく、れーれゅれろれろれろ、じゅぽぽぽぽぽ……ちゅううう!"

# Example texts offered by the "pick from examples" button.
examples = Path(example_file).read_text(encoding="utf-8").splitlines()
def get_random_text() -> str:
    """Return a randomly chosen example text.

    Falls back to ``initial_text`` when ``examples`` is empty, because
    ``random.choice`` raises ``IndexError`` on an empty sequence.

    Returns:
        One entry of the module-level ``examples`` list, or ``initial_text``
        if that list has no entries.
    """
    if not examples:
        return initial_text
    return random.choice(examples)
# Markdown rendered at the top of the app: title, NSFW notice and changelog.
initial_md = """
# チュパ音合成デモ
*NSFW注意*: このデモは、性的な音声を生成するためのものです。
- 2024-07-11: chupa_3を追加
- 2024-07-10: chupa_2を追加
- 2024-07-07: initial ver
"""
def make_interactive():
    """Return a Gradio update that enables a component and labels it "音声合成"."""
    enabled_update = gr.update(interactive=True, value="音声合成")
    return enabled_update
def make_non_interactive():
    """Return a Gradio update that disables a component and asks for a model load."""
    disabled_update = gr.update(
        interactive=False,
        value="音声合成(モデルをロードしてください)",
    )
    return disabled_update
def gr_util(item):
    """Return a pair of component updates selected by ``item``.

    When ``item`` equals "プリセットから選ぶ" the first component is shown and
    the second is replaced by a hidden, empty ``gr.Audio``; for any other
    value the first is hidden and the second is shown.
    """
    if item != "プリセットから選ぶ":
        return (gr.update(visible=False), gr.update(visible=True))
    return (gr.update(visible=True), gr.Audio(visible=False, value=None))
def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks:
    """Build the Gradio Blocks UI for text-to-speech inference.

    Args:
        model_holder: Holder that lists the available models and loads them.

    Returns:
        The assembled ``gr.Blocks`` app. When the holder reports no models,
        a minimal app that only shows an error message is returned instead.
    """

    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is visible on tts_fn -- confirm whether one is required.
    def tts_fn(
        model_name,
        model_path,
        text,
        language,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        line_split,
        split_interval,
        speaker,
    ):
        """Synthesize ``text`` with the selected model.

        Returns:
            ``(message, audio)`` where ``audio`` is the ``(sr, audio)`` pair
            produced by the model's ``infer`` on success, or ``None`` when
            synthesis failed (``message`` then starts with "Error:").
        """
        model_holder.get_model(model_name, model_path)
        current_model = model_holder.current_model
        # Explicit check instead of `assert`, which is stripped under -O.
        if current_model is None:
            logger.error("Model is not loaded")
            return "Error: モデルをロードできませんでした。", None

        # The speaker dropdown can be stale (or empty) relative to the model
        # that was just loaded, so report an unknown speaker instead of
        # letting the KeyError escape.
        try:
            speaker_id = current_model.spk2id[speaker]
        except KeyError:
            logger.error(f"Unknown speaker: {speaker}")
            return (
                f"Error: 話者「{speaker}」は現在のモデルに存在しません。"
                "モデルをロードし直してください。",
                None,
            )

        start_time = datetime.datetime.now()
        try:
            sr, audio = current_model.infer(
                text=text,
                language=language,
                sdp_ratio=sdp_ratio,
                noise=noise_scale,
                noise_w=noise_scale_w,
                length=length_scale,
                line_split=line_split,
                split_interval=split_interval,
                speaker_id=speaker_id,
            )
        except InvalidToneError as e:
            logger.error(f"Tone error: {e}")
            return f"Error: アクセント指定が不正です:\n{e}", None
        except ValueError as e:
            logger.error(f"Value error: {e}")
            return f"Error: {e}", None

        end_time = datetime.datetime.now()
        duration = (end_time - start_time).total_seconds()
        message = f"Success, time: {duration} seconds."
        return message, (sr, audio)

    def get_model_files(model_name: str):
        """Return the model file paths registered for ``model_name`` as strings."""
        return [str(f) for f in model_holder.model_files_dict[model_name]]

    model_names = model_holder.model_names
    if not model_names:
        # Build the message once; it is both logged and shown in the UI.
        not_found = (
            f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
        )
        logger.error(not_found)
        with gr.Blocks() as app:
            gr.Markdown(f"Error: {not_found}")
        return app

    # Load the first file of the first model to get the initial speaker list.
    initial_pth_files = get_model_files(model_names[0])
    model = model_holder.get_model(model_names[0], initial_pth_files[0])
    speakers = list(model.spk2id.keys())

    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from source whose indentation was lost -- confirm against the running UI.
    with gr.Blocks(theme="ParityError/Anime") as app:
        gr.Markdown(initial_md)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        model_name = gr.Dropdown(
                            label="モデル一覧",
                            choices=model_names,
                            value=model_names[0],
                        )
                        model_path = gr.Dropdown(
                            label="モデルファイル",
                            choices=initial_pth_files,
                            value=initial_pth_files[0],
                        )
                    refresh_button = gr.Button("更新", scale=1, visible=False)
                    load_button = gr.Button("ロード", scale=1, variant="primary")
                with gr.Row():
                    text_input = gr.TextArea(
                        label="テキスト", value=initial_text, scale=3
                    )
                    random_button = gr.Button("例から選ぶ 🎲", scale=1)
                    random_button.click(get_random_text, outputs=[text_input])
                with gr.Row():
                    length_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_LENGTH,
                        step=0.1,
                        label="生成音声の長さ(Length)",
                    )
                    sdp_ratio = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=1,
                        step=0.1,
                        label="SDP Ratio",
                    )
                line_split = gr.Checkbox(
                    label="改行で分けて生成(分けたほうが感情が乗ります)",
                    value=DEFAULT_LINE_SPLIT,
                    visible=False,
                )
                split_interval = gr.Slider(
                    minimum=0.0,
                    maximum=2,
                    value=DEFAULT_SPLIT_INTERVAL,
                    step=0.1,
                    label="改行ごとに挟む無音の長さ(秒)",
                )
                # Show the interval slider only while line splitting is on.
                line_split.change(
                    lambda x: (gr.Slider(visible=x)),
                    inputs=[line_split],
                    outputs=[split_interval],
                )
                language = gr.Dropdown(
                    choices=["JP"], value="JP", label="Language", visible=False
                )
                speaker = gr.Dropdown(label="話者", choices=speakers, value=speakers[0])
                with gr.Accordion(label="詳細設定", open=True):
                    noise_scale = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISE,
                        step=0.1,
                        label="Noise",
                    )
                    noise_scale_w = gr.Slider(
                        minimum=0.1,
                        maximum=2,
                        value=DEFAULT_NOISEW,
                        step=0.1,
                        label="Noise_W",
                    )
            with gr.Column():
                tts_button = gr.Button("音声合成", variant="primary")
                text_output = gr.Textbox(label="情報")
                audio_output = gr.Audio(label="結果")

        tts_button.click(
            tts_fn,
            inputs=[
                model_name,
                model_path,
                text_input,
                language,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                line_split,
                split_interval,
                speaker,
            ],
            outputs=[text_output, audio_output],
        )

        model_name.change(
            model_holder.update_model_files_for_gradio,
            inputs=[model_name],
            outputs=[model_path],
        )

        # Changing the model file disables synthesis until "ロード" is pressed.
        model_path.change(make_non_interactive, outputs=[tts_button])

        refresh_button.click(
            model_holder.update_model_names_for_gradio,
            outputs=[model_name, model_path, tts_button],
        )

        # Hidden dropdown that receives the first output of
        # get_model_for_gradio.
        style = gr.Dropdown(label="スタイル", choices=[], visible=False)

        load_button.click(
            model_holder.get_model_for_gradio,
            inputs=[model_name, model_path],
            outputs=[style, tts_button, speaker],
        )

    return app
if __name__ == "__main__":
    import torch

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.nlp import bert_models

    # Load the BERT model and tokenizer for Languages.JP before building the app.
    bert_models.load_model(Languages.JP)
    bert_models.load_tokenizer(Languages.JP)

    # Use CUDA when torch reports it as available, otherwise the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model_holder = TTSModelHolder(Path("model_assets"), device)
    app = create_inference_app(model_holder)
    app.launch(inbrowser=True)