"""Gradio demo that generates audio with a Parler-TTS checkpoint.

On import this module runs ``install.sh``, seeds the torch RNGs, picks a
compute device, and loads the model and tokenizer. Running it as a script
launches a Gradio interface around :func:`generate_audio`.
"""

import os

import gradio as gr
import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed

# Runs the setup script from the current working directory; the exit status
# is not checked.
# NOTE(review): this executes after the imports above, so it cannot supply
# packages those imports need -- confirm what install.sh is meant to set up.
os.system("bash install.sh")

MODEL_ID = "AkhilTolani/parler-tts-music-200000"
OUTPUT_PATH = "parler_tts_out.wav"

# Set the seed for reproducibility
seed = 456


def _mps_available() -> bool:
    """Return True when this torch build has an available MPS backend."""
    mps_backend = getattr(torch.backends, "mps", None)
    return mps_backend is not None and mps_backend.is_available()


def _xpu_available() -> bool:
    """Return True when this torch build has an available XPU backend."""
    # Not every torch build ships a ``torch.xpu`` module, so probe for it
    # instead of assuming the attribute exists.
    xpu = getattr(torch, "xpu", None)
    return xpu is not None and xpu.is_available()


def _select_device() -> str:
    """Pick the compute device, preferring xpu, then mps, then cuda, then cpu."""
    if _xpu_available():
        return "xpu"
    if _mps_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"


torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
if _mps_available() and hasattr(torch, "mps"):
    # The seeding function lives in ``torch.mps``; ``torch.backends.mps`` only
    # exposes availability queries.
    torch.mps.manual_seed(seed)
if _xpu_available():
    torch.xpu.manual_seed(seed)

device = _select_device()

# Half precision on any accelerator, full precision on CPU.
torch_dtype = torch.float16 if device != "cpu" else torch.float32

model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID).to(
    device, dtype=torch_dtype
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def generate_audio(prompt: str, description: str) -> str:
    """Generate audio for ``prompt`` conditioned on ``description``.

    Both strings are tokenized with the module-level tokenizer; the
    description is passed to the model as ``input_ids`` and the prompt as
    ``prompt_input_ids``. The result is written to ``parler_tts_out.wav`` in
    the current working directory, overwriting any previous output.

    Args:
        prompt: Text passed to the model as ``prompt_input_ids``.
        description: Text passed to the model as ``input_ids``.

    Returns:
        The path of the WAV file that was written.
    """
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Define num_codebooks from the model configuration
    num_codebooks = model.decoder.config.num_codebooks

    # Set up generation arguments
    gen_kwargs = {
        "do_sample": False,
        "temperature": 1.0,
        "max_length": 2580,
        "min_new_tokens": num_codebooks + 1,
    }

    # Re-seed before every call so repeated requests start from the same
    # RNG state.
    set_seed(seed)

    # Generate the output and convert it to float32 before handing it to
    # numpy / soundfile.
    generation = model.generate(
        input_ids=input_ids,
        prompt_input_ids=prompt_input_ids,
        **gen_kwargs,
    ).to(torch.float32)

    audio_arr = generation.cpu().numpy().squeeze()
    sf.write(OUTPUT_PATH, audio_arr, model.config.sampling_rate)
    return OUTPUT_PATH


interface = gr.Interface(
    fn=generate_audio,
    inputs=[gr.Textbox(label="Prompt"), gr.Textbox(label="Description")],
    outputs=gr.Audio(label="Generated Audio"),
    title="Parler TTS Audio Generation",
    description=(
        "Generate audio using the Parler TTS model. "
        "Provide a prompt and description to generate the corresponding audio."
    ),
)

if __name__ == "__main__":
    interface.launch()