Spaces:
Runtime error
Runtime error
AkhilTolani
committed on
Commit
•
5cb8d27
1
Parent(s):
29af7a4
Update app.py
Browse files
app.py
CHANGED
@@ -1,43 +1,42 @@
|
|
1 |
import gradio as gr
|
2 |
-
from parler_tts import ParlerTTSForConditionalGeneration
|
3 |
from transformers import AutoTokenizer, set_seed
|
4 |
import soundfile as sf
|
5 |
import torch
|
6 |
import os
|
|
|
|
|
7 |
|
8 |
os.system("bash install.sh")
|
9 |
|
10 |
-
device
|
11 |
-
|
12 |
-
|
13 |
-
if
|
14 |
-
|
15 |
-
if torch.xpu.is_available():
|
16 |
-
device = "xpu"
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
|
22 |
|
23 |
def generate_audio(prompt, description, seed, temperature, max_length, do_sample):
|
24 |
-
# Set the seed for reproducibility
|
25 |
seed = int(seed)
|
26 |
-
|
27 |
-
if torch.cuda.is_available():
|
28 |
-
torch.cuda.manual_seed_all(seed)
|
29 |
-
if torch.backends.mps.is_available():
|
30 |
-
torch.backends.mps.manual_seed(seed)
|
31 |
-
if torch.xpu.is_available():
|
32 |
-
torch.xpu.manual_seed(seed)
|
33 |
|
34 |
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
|
35 |
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
36 |
|
37 |
-
# Define num_codebooks from the model configuration
|
38 |
num_codebooks = model.decoder.config.num_codebooks
|
39 |
|
40 |
-
# Set up generation arguments
|
41 |
gen_kwargs = {
|
42 |
"do_sample": do_sample,
|
43 |
"temperature": temperature,
|
@@ -45,22 +44,53 @@ def generate_audio(prompt, description, seed, temperature, max_length, do_sample
|
|
45 |
"min_new_tokens": num_codebooks + 1,
|
46 |
}
|
47 |
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
input_ids=input_ids,
|
53 |
-
prompt_input_ids=prompt_input_ids,
|
54 |
-
**gen_kwargs
|
55 |
-
).to(torch.float32)
|
56 |
|
57 |
-
|
|
|
58 |
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
|
59 |
|
60 |
return "parler_tts_out.wav"
|
61 |
|
62 |
-
|
63 |
-
|
|
|
64 |
default_seed = "456"
|
65 |
|
66 |
interface = gr.Interface(
|
@@ -69,14 +99,14 @@ interface = gr.Interface(
|
|
69 |
gr.Textbox(label="Prompt", value=default_prompt),
|
70 |
gr.Textbox(label="Description", value=default_description),
|
71 |
gr.Textbox(label="Seed", value=default_seed),
|
72 |
-
gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0
|
73 |
gr.Slider(label="Max Length", minimum=256, maximum=5120, step=256, value=2580),
|
74 |
gr.Dropdown(label="Do Sample", choices=[True, False], value=True)
|
75 |
],
|
76 |
outputs=gr.Audio(label="Generated Audio"),
|
77 |
title="Parler TTS Audio Generation",
|
78 |
-
description="Generate
|
79 |
)
|
80 |
|
81 |
if __name__ == "__main__":
|
82 |
-
interface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig
|
3 |
from transformers import AutoTokenizer, set_seed
|
4 |
import soundfile as sf
|
5 |
import torch
|
6 |
import os
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from accelerate.utils import set_seed, AutocastKwargs
|
9 |
|
10 |
os.system("bash install.sh")
|
11 |
|
12 |
+
# Set up the accelerator and derive the compute device and precision from it.
accelerator = Accelerator()
device = accelerator.device
# `accelerator.device` is a torch.device object, and comparing it with the
# string "cpu" never evaluates to True. Inspect `device.type` instead so that
# CPU hosts really fall back to full precision.
mixed_precision = "no" if device.type == "cpu" else "bf16"
torch_dtype = torch.float32 if device.type == "cpu" else torch.bfloat16
|
|
|
|
|
17 |
|
18 |
+
# Fetch the fine-tuned checkpoint together with its configuration and hand the
# loaded model straight to the accelerator.
model_path = "AkhilTolani/parler-tts-finetune-vocals-only-large-18720-steps"
config = ParlerTTSConfig.from_pretrained(model_path)
model = accelerator.prepare(
    ParlerTTSForConditionalGeneration.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch_dtype,
        attn_implementation="sdpa",
    )
)

# Tokenizer applied to both the description and the prompt text.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
|
30 |
|
31 |
def generate_audio(prompt, description, seed, temperature, max_length, do_sample):
|
|
|
32 |
seed = int(seed)
|
33 |
+
set_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
|
36 |
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
37 |
|
|
|
38 |
num_codebooks = model.decoder.config.num_codebooks
|
39 |
|
|
|
40 |
gen_kwargs = {
|
41 |
"do_sample": do_sample,
|
42 |
"temperature": temperature,
|
|
|
44 |
"min_new_tokens": num_codebooks + 1,
|
45 |
}
|
46 |
|
47 |
+
# Prepare for generation
|
48 |
+
batch = {
|
49 |
+
"input_ids": input_ids,
|
50 |
+
"prompt_input_ids": prompt_input_ids,
|
51 |
+
}
|
52 |
+
|
53 |
+
autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "no"))
|
54 |
+
|
55 |
+
# Generation step
|
56 |
+
def generate_step(batch, accelerator):
    """Run one generation pass and return the audio as a float32 tensor.

    Reads ``model``, ``config``, ``gen_kwargs``, ``autocast_kwargs`` and
    ``mixed_precision`` from the enclosing scope.

    Args:
        batch: Keyword arguments forwarded to ``generate``; the caller
            supplies ``input_ids`` and ``prompt_input_ids``. The dict is
            modified in place (``decoder_attention_mask`` is dropped and,
            on the fp16 path, ``encoder_outputs`` is added).
        accelerator: The ``Accelerator`` used to unwrap the model and to
            open the autocast context.

    Returns:
        The generated audio tensor, cast to ``torch.float32``.
    """
    batch.pop("decoder_attention_mask", None)
    eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

    # AutocastKwargs is a handler object rather than a mapping, so it cannot
    # be **-unpacked; it has to be passed through `autocast_handler`.
    # NOTE(review): generation is kept inside the autocast context -- confirm
    # this matches the intended nesting.
    with accelerator.autocast(autocast_handler=autocast_kwargs):
        if mixed_precision == "fp16":
            # Pre-compute the text-encoder states so they are produced under
            # autocast, then pass them to `generate` via `encoder_outputs`.
            attention_mask = batch.get("attention_mask", None)
            encoder_outputs = eval_model.text_encoder(
                input_ids=batch.get("input_ids"),
                attention_mask=attention_mask,
            )
            encoder_hidden_states = encoder_outputs.last_hidden_state
            if (
                config.text_encoder.hidden_size != config.decoder.hidden_size
                and config.decoder.cross_attention_hidden_size is None
            ):
                # Project encoder states to the decoder width when they differ.
                encoder_hidden_states = eval_model.enc_to_dec_proj(encoder_hidden_states)

            if attention_mask is not None:
                # Zero out the hidden states at masked positions.
                encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

            encoder_outputs.last_hidden_state = encoder_hidden_states
            batch["encoder_outputs"] = encoder_outputs

        output_audios = eval_model.generate(**batch, **gen_kwargs)

    # The caller converts the result with `.cpu().numpy()`; NumPy has no
    # bfloat16 dtype, so hand back float32 regardless of the model precision.
    return output_audios.to(torch.float32)
|
81 |
|
82 |
+
with torch.no_grad():
|
83 |
+
generated_audios = generate_step(batch, accelerator)
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
# Post-process the generated audio
|
86 |
+
audio_arr = generated_audios.cpu().numpy().squeeze()
|
87 |
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
|
88 |
|
89 |
return "parler_tts_out.wav"
|
90 |
|
91 |
+
# Gradio interface setup
|
92 |
+
default_prompt = "thought no beef im hate to get murder right in these streets i told yall niggins is dead fucking green tbs and tsg my shit only you cant beat out if you aint going to aim and squeeze take your mvp out the game just like a referee im talking about my life you just rapping on beats i be clapping on streets theyre using technology to try to find where the bullets coming from they wont find those z nope because im a smooth criminal i got some screwed loose because im a sick of the"
|
93 |
+
default_description = "A male vocalist delivers an energetic and passionate freestyle in a medium-fast tempo, showcasing an enthusiastic and emotional performance with emphatic expression, conveying a youthful and groovy vibe throughout the track."
|
94 |
default_seed = "456"
|
95 |
|
96 |
interface = gr.Interface(
|
|
|
99 |
gr.Textbox(label="Prompt", value=default_prompt),
|
100 |
gr.Textbox(label="Description", value=default_description),
|
101 |
gr.Textbox(label="Seed", value=default_seed),
|
102 |
+
gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=1.0),
|
103 |
gr.Slider(label="Max Length", minimum=256, maximum=5120, step=256, value=2580),
|
104 |
gr.Dropdown(label="Do Sample", choices=[True, False], value=True)
|
105 |
],
|
106 |
outputs=gr.Audio(label="Generated Audio"),
|
107 |
title="Parler TTS Audio Generation",
|
108 |
+
description="Generate vocals using the Parler TTS model. Provide a prompt, description, and seed to generate the corresponding audio."
|
109 |
)
|
110 |
|
111 |
if __name__ == "__main__":
|
112 |
+
interface.launch()
|