# vocals / app.py — Hugging Face Space by AkhilTolani
# Last commit: be0bc58 "Update app.py" (verified); file size 2.25 kB.
import os
import tempfile

import gradio as gr
import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
# Run the Space's bootstrap script at import time. os.system does not raise on
# failure, so surface a non-zero status in the logs instead of discarding it;
# startup continues either way (best-effort).
_install_status = os.system("bash install.sh")
if _install_status != 0:
    print(f"WARNING: 'bash install.sh' returned status {_install_status}", flush=True)
# Set the seed for reproducibility.
seed = 456
# torch.manual_seed seeds the default CPU generator (recent PyTorch releases
# also seed accelerator generators from it); the explicit per-backend calls
# below cover older releases.
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# MPS seeding lives in ``torch.mps``; ``torch.backends.mps`` only exposes
# availability queries, so ``torch.backends.mps.manual_seed`` does not exist.
if torch.backends.mps.is_available() and hasattr(torch, "mps"):
    torch.mps.manual_seed(seed)
# ``torch.xpu`` is missing from older PyTorch builds; check before touching it.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    torch.xpu.manual_seed_all(seed)
# Select the compute device. Precedence is XPU, then Apple MPS, then the first
# CUDA GPU, falling back to CPU.
# NOTE(review): XPU is preferred over CUDA here; confirm that is intended on
# hosts exposing both.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    # Guarded with hasattr: older PyTorch builds have no ``torch.xpu``.
    device = "xpu"
elif torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Half precision on any accelerator, full precision on CPU.
torch_dtype = torch.float16 if device != "cpu" else torch.float32
# Hub checkpoint shared by the Parler-TTS model and its tokenizer.
_MODEL_REPO = "AkhilTolani/parler-tts-music-200000"

# Load the weights, then move them to the selected device and precision.
model = ParlerTTSForConditionalGeneration.from_pretrained(_MODEL_REPO)
model = model.to(device, dtype=torch_dtype)

tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
def generate_audio(prompt, description):
    """Generate audio for ``prompt`` conditioned on ``description``.

    Both strings are tokenized with the module-level ``tokenizer`` and passed
    to ``model.generate`` with ``do_sample=False`` (greedy decoding). The
    global RNGs are reset to ``seed`` before each generation.

    Args:
        prompt: Text passed to the model as ``prompt_input_ids``.
        description: Text passed to the model as ``input_ids``.

    Returns:
        Path of a newly created ``.wav`` file holding the generated audio,
        written at ``model.config.sampling_rate``.
    """
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Number of codebooks, read from the decoder configuration.
    num_codebooks = model.decoder.config.num_codebooks
    gen_kwargs = {
        # Greedy decoding; temperature only takes effect when sampling.
        "do_sample": False,
        "temperature": 1.0,
        "max_length": 2580,
        # Require at least num_codebooks + 1 newly generated tokens.
        "min_new_tokens": num_codebooks + 1,
    }

    set_seed(seed)
    generation = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        **gen_kwargs,
    )
    # The model may run in float16 (see torch_dtype); cast to float32 before
    # handing the samples to NumPy/soundfile.
    audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()

    # Write to a fresh temporary file rather than a fixed name in the working
    # directory: a shared path can be overwritten by an overlapping request
    # before Gradio reads it, and the working directory may not be writable.
    # delete=False keeps the file after the handle is closed so Gradio can
    # read it; closing first also lets soundfile reopen it on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, audio_arr, model.config.sampling_rate)
    return out_path
# Two free-text inputs, passed positionally to generate_audio(prompt, description).
_ui_inputs = [
    gr.Textbox(label="Prompt"),
    gr.Textbox(label="Description"),
]
# generate_audio returns a file path, which the Audio component plays back.
_ui_output = gr.Audio(label="Generated Audio")

interface = gr.Interface(
    fn=generate_audio,
    inputs=_ui_inputs,
    outputs=_ui_output,
    title="Parler TTS Audio Generation",
    description="Generate audio using the Parler TTS model. Provide a prompt and description to generate the corresponding audio.",
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()