# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import gradio as gr | |
import spaces | |
import torch | |
from loguru import logger | |
from parler_tts import ParlerTTSForConditionalGeneration | |
from rubyinserter import add_ruby | |
from transformers import AutoTokenizer | |
# Use the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
logger.info(f"Using device: {device}")

# Hugging Face Hub repository supplying both the model weights and the tokenizer.
repo_id = "2121-8/japanese-parler-tts-mini-bate"
logger.info(f"Loading model from: {repo_id}")

# Load the model once at import time, move it to the selected device and
# switch it to eval mode.
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id)
model = model.to(device)
logger.success("Model loaded successfully")
model.eval()

# The same tokenizer is used for both the voice description and the text prompt.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Input limits, in characters. These match the limits stated in the page text
# (150 / 300 "文字以内").
MAX_PROMPT_LENGTH = 150
MAX_DESCRIPTION_LENGTH = 300


def _validate_inputs(prompt, description):
    """Return a user-facing error message for invalid input, or None if valid.

    The prompt limit applies to the raw text, before ruby insertion.
    """
    # `not prompt` also covers a None value; whitespace-only text is rejected too.
    if not prompt or not prompt.strip():
        return "Text is empty. Please enter some text."
    if len(prompt) > MAX_PROMPT_LENGTH:
        return f"Text is too long. Please keep it within {MAX_PROMPT_LENGTH} characters."
    if len(description) > MAX_DESCRIPTION_LENGTH:
        return (
            "Description is too long. "
            f"Please keep it within {MAX_DESCRIPTION_LENGTH} characters."
        )
    return None


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function is running. Outside ZeroGPU the decorator is documented as a no-op.
@spaces.GPU
def parler_tts(prompt: str, description: str):
    """Synthesize speech for *prompt* in the voice described by *description*.

    Args:
        prompt: Text to speak. It must be non-empty and at most
            MAX_PROMPT_LENGTH characters.
        description: Free-text description of the desired voice, at most
            MAX_DESCRIPTION_LENGTH characters.

    Returns:
        A ``(status_message, audio)`` pair. On success ``audio`` is a
        ``(sampling_rate, waveform)`` tuple, where ``waveform`` is a numpy
        array. When validation fails, ``audio`` is None and the message
        explains why.
    """
    logger.info(f"Prompt: {prompt}")
    logger.info(f"Description: {description}")

    error = _validate_inputs(prompt, description)
    if error is not None:
        return error, None

    # add_ruby (rubyinserter) presumably annotates the Japanese text with
    # readings. The model is fed the annotated text.
    prompt = add_ruby(prompt)
    logger.info(f"Prompt with ruby: {prompt}")

    # The description is passed as `input_ids` and the text to speak as
    # `prompt_input_ids`.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        generation = model.generate(
            input_ids=input_ids, prompt_input_ids=prompt_input_ids
        )
    # Drop singleton dimensions so Gradio receives a plain waveform array.
    audio_arr = generation.cpu().numpy().squeeze()
    return "Success", (model.config.sampling_rate, audio_arr)
# Markdown shown at the top of the demo page (Japanese UI text). It says this is
# a third-party demo of Japanese Parler-TTS Mini (beta). It also states the
# limits: input text within 150 characters, and a voice description within
# 300 characters.
md = (
    "\n"
    "# Japanese Parler-TTS Mini (β版) デモ\n"
    "第三者による [Japanese Parler-TTS Mini (β版)]"
    "(https://huggingface.co/2121-8/japanese-parler-tts-mini-bate) の音声合成デモです。\n"
    "- 入力文章: 150文字以内の文章を入力してください。\n"
    "- 説明文章: 300文字以内の文章を入力してください。"
    "音声の特徴を説明する文章を入力します(多分)。\n"
)
# Build the demo page. It has a Markdown header, an input-text box, a voice
# description box (pre-filled with an example), a generate button, and outputs
# for the status message and the synthesized audio.
with gr.Blocks() as app:
    gr.Markdown(md)
    prompt = gr.Textbox(label="入力文章")
    description = gr.Textbox(
        label="説明文章",
        value="A female speaker with a slightly high-pitched voice delivers her words at a moderate speed with a quite monotone tone in a confined environment, resulting in a quite clear audio recording.",
    )
    btn = gr.Button("生成")
    info_text = gr.Textbox(label="情報")
    audio = gr.Audio()
    # parler_tts returns (status_message, audio), in the same order as `outputs`.
    btn.click(
        fn=parler_tts,
        inputs=[prompt, description],
        outputs=[info_text, audio],
    )

# Start the server only when this file is executed as a script, so importing
# the module does not launch it as a side effect.
# NOTE(review): assumes the host runs this file as __main__ (e.g. `python app.py`).
if __name__ == "__main__":
    app.launch()