# Spaces status (from the hosting page header): Runtime error
import gradio as gr | |
from parler_tts import ParlerTTSForConditionalGeneration | |
from transformers import AutoTokenizer | |
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed | |
import soundfile as sf | |
import torch | |
import os | |
# Run the project's install.sh setup script before the model is loaded.
# NOTE(review): this executes after the imports at the top of the file, so
# anything install.sh provides is presumably not required by those imports —
# confirm against install.sh.
# os.system() signals failure only through its return value; check it so a
# failed install is visible in the logs instead of surfacing later as a less
# obvious error. Setup stays best-effort: warn, do not abort.
_install_status = os.system("bash install.sh")
if _install_status != 0:
    print(f"WARNING: 'bash install.sh' exited with status {_install_status}")
# Seed every available random-number generator so runs are reproducible.
seed = 456
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# The MPS seeding function lives in torch.mps; torch.backends.mps only
# exposes capability checks, so calling manual_seed on it raises
# AttributeError on machines where MPS is available.
if torch.backends.mps.is_available() and hasattr(torch, "mps"):
    torch.mps.manual_seed(seed)
# torch.xpu is missing from some PyTorch builds; probe for the module before
# using it so start-up does not fail with AttributeError.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    torch.xpu.manual_seed(seed)
# Pick the compute device. Priority is xpu > mps > cuda > cpu.
# torch.xpu is missing from some PyTorch builds, so probe for the module
# before calling into it to avoid an AttributeError at start-up.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
elif torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Half precision on any accelerator; full precision on CPU.
torch_dtype = torch.float16 if device != "cpu" else torch.float32
# Load the pretrained checkpoint and its tokenizer once at start-up; the
# model is then moved to the selected device and dtype.
checkpoint = "AkhilTolani/parler-tts-music-200000"
model = ParlerTTSForConditionalGeneration.from_pretrained(checkpoint)
model = model.to(device, dtype=torch_dtype)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
def generate_audio(prompt, description):
    """Generate audio for ``prompt`` conditioned on ``description``.

    Both strings are tokenized with the module-level ``tokenizer`` and handed
    to ``model.generate`` (the description as ``input_ids``, the prompt as
    ``prompt_input_ids``). The generated audio is written to a WAV file.

    Args:
        prompt: Text tokenized into ``prompt_input_ids``.
        description: Text tokenized into ``input_ids``.

    Returns:
        The path of the written file, always ``"parler_tts_out.wav"``; every
        call overwrites the output of the previous one.
    """
    output_path = "parler_tts_out.wav"

    def encode(text):
        # Token ids as a PyTorch tensor placed on the inference device.
        return tokenizer(text, return_tensors="pt").input_ids.to(device)

    description_ids = encode(description)
    prompt_ids = encode(prompt)

    # The minimum number of new tokens is derived from the decoder's
    # codebook count.
    min_new_tokens = model.decoder.config.num_codebooks + 1

    # Re-seed immediately before generating on every call.
    set_seed(seed)
    waveform = model.generate(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        do_sample=False,
        temperature=1.0,
        max_length=2580,
        min_new_tokens=min_new_tokens,
    )

    # The model may run in float16; cast to float32 before converting to a
    # NumPy array on the host and dropping singleton dimensions.
    samples = waveform.to(torch.float32).cpu().numpy().squeeze()
    sf.write(output_path, samples, model.config.sampling_rate)
    return output_path
# Gradio UI: two text inputs and one audio output. The input order matches
# the positional parameters of generate_audio (prompt, description).
prompt_box = gr.Textbox(label="Prompt")
description_box = gr.Textbox(label="Description")
audio_player = gr.Audio(label="Generated Audio")

interface = gr.Interface(
    fn=generate_audio,
    inputs=[prompt_box, description_box],
    outputs=audio_player,
    title="Parler TTS Audio Generation",
    description="Generate audio using the Parler TTS model. Provide a prompt and description to generate the corresponding audio.",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()