AkhilTolani commited on
Commit
5cb8d27
1 Parent(s): 29af7a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -33
app.py CHANGED
@@ -1,43 +1,42 @@
1
  import gradio as gr
2
- from parler_tts import ParlerTTSForConditionalGeneration
3
  from transformers import AutoTokenizer, set_seed
4
  import soundfile as sf
5
  import torch
6
  import os
 
 
7
 
8
  os.system("bash install.sh")
9
 
10
- device = "cpu"
11
- if torch.cuda.is_available():
12
- device = "cuda:0"
13
- if torch.backends.mps.is_available():
14
- device = "mps"
15
- if torch.xpu.is_available():
16
- device = "xpu"
17
 
18
- torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
19
- model = ParlerTTSForConditionalGeneration.from_pretrained("AkhilTolani/parler-tts-finetune-vocals-only-large-18720-steps", torch_dtype=torch_dtype, attn_implementation="sdpa").to(device)
 
 
 
 
 
 
 
 
20
 
21
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
22
 
23
  def generate_audio(prompt, description, seed, temperature, max_length, do_sample):
24
- # Set the seed for reproducibility
25
  seed = int(seed)
26
- torch.manual_seed(seed)
27
- if torch.cuda.is_available():
28
- torch.cuda.manual_seed_all(seed)
29
- if torch.backends.mps.is_available():
30
- torch.backends.mps.manual_seed(seed)
31
- if torch.xpu.is_available():
32
- torch.xpu.manual_seed(seed)
33
 
34
  input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
35
  prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
36
 
37
- # Define num_codebooks from the model configuration
38
  num_codebooks = model.decoder.config.num_codebooks
39
 
40
- # Set up generation arguments
41
  gen_kwargs = {
42
  "do_sample": do_sample,
43
  "temperature": temperature,
@@ -45,22 +44,53 @@ def generate_audio(prompt, description, seed, temperature, max_length, do_sample
45
  "min_new_tokens": num_codebooks + 1,
46
  }
47
 
48
- set_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Generate the output
51
- generation = model.generate(
52
- input_ids=input_ids,
53
- prompt_input_ids=prompt_input_ids,
54
- **gen_kwargs
55
- ).to(torch.float32)
56
 
57
- audio_arr = generation.cpu().numpy().squeeze()
 
58
  sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
59
 
60
  return "parler_tts_out.wav"
61
 
62
- default_prompt = "free monster when you fucking hoes walk in the room chilettoes sitting up straight my fucking pasha been exposed somoomits start to hurt ive been shitting real gold doing dumb shit you never know or never seen they call me leopard i stay with the green"
63
- default_description = "Experience the vibrant energy of a hip hop track featuring a male rapper delivering smooth verses over a catchy synth lead melody, supported by punchy kicks, deep 808 bass, claps, and shimmering hi-hats. The song transitions seamlessly to a female vocalist singing melodically alongside a raw synth bass line. Perfect for setting the mood in a solo dance session at home or keeping the party going in a lively club environment."
 
64
  default_seed = "456"
65
 
66
  interface = gr.Interface(
@@ -69,14 +99,14 @@ interface = gr.Interface(
69
  gr.Textbox(label="Prompt", value=default_prompt),
70
  gr.Textbox(label="Description", value=default_description),
71
  gr.Textbox(label="Seed", value=default_seed),
72
- gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.75),
73
  gr.Slider(label="Max Length", minimum=256, maximum=5120, step=256, value=2580),
74
  gr.Dropdown(label="Do Sample", choices=[True, False], value=True)
75
  ],
76
  outputs=gr.Audio(label="Generated Audio"),
77
  title="Parler TTS Audio Generation",
78
- description="Generate audio using the Parler TTS model. Provide a prompt, description, and seed to generate the corresponding audio."
79
  )
80
 
81
  if __name__ == "__main__":
82
- interface.launch()
 
1
  import gradio as gr
2
+ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig
3
  from transformers import AutoTokenizer, set_seed
4
  import soundfile as sf
5
  import torch
6
  import os
7
+ from accelerate import Accelerator
8
+ from accelerate.utils import set_seed, AutocastKwargs
9
 
10
  os.system("bash install.sh")
11
 
12
+ # Setup device and accelerator
13
+ accelerator = Accelerator()
14
+ device = accelerator.device
15
+ mixed_precision = "no" if device == "cpu" else "bf16"
16
+ torch_dtype = torch.float32 if device == "cpu" else torch.bfloat16
 
 
17
 
18
+ # Load model and tokenizer
19
+ model_path = "AkhilTolani/parler-tts-finetune-vocals-only-large-18720-steps"
20
+ config = ParlerTTSConfig.from_pretrained(model_path)
21
+ model = ParlerTTSForConditionalGeneration.from_pretrained(
22
+ model_path,
23
+ config=config,
24
+ torch_dtype=torch_dtype,
25
+ attn_implementation="sdpa"
26
+ )
27
+ model = accelerator.prepare(model)
28
 
29
  tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
30
 
31
  def generate_audio(prompt, description, seed, temperature, max_length, do_sample):
 
32
  seed = int(seed)
33
+ set_seed(seed)
 
 
 
 
 
 
34
 
35
  input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
36
  prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
37
 
 
38
  num_codebooks = model.decoder.config.num_codebooks
39
 
 
40
  gen_kwargs = {
41
  "do_sample": do_sample,
42
  "temperature": temperature,
 
44
  "min_new_tokens": num_codebooks + 1,
45
  }
46
 
47
+ # Prepare for generation
48
+ batch = {
49
+ "input_ids": input_ids,
50
+ "prompt_input_ids": prompt_input_ids,
51
+ }
52
+
53
+ autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "no"))
54
+
55
+ # Generation step
56
+ def generate_step(batch, accelerator):
57
+ batch.pop("decoder_attention_mask", None)
58
+ eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
59
+
60
+ with accelerator.autocast(**autocast_kwargs):
61
+ if mixed_precision == "fp16":
62
+ encoder_outputs = eval_model.text_encoder(
63
+ input_ids=batch.get("input_ids"),
64
+ attention_mask=batch.get("attention_mask", None)
65
+ )
66
+ encoder_hidden_states = encoder_outputs.last_hidden_state
67
+ if (
68
+ config.text_encoder.hidden_size != config.decoder.hidden_size
69
+ and config.decoder.cross_attention_hidden_size is None
70
+ ):
71
+ encoder_hidden_states = eval_model.enc_to_dec_proj(encoder_hidden_states)
72
+
73
+ if batch.get("attention_mask", None) is not None:
74
+ encoder_hidden_states = encoder_hidden_states * batch.get("attention_mask", None)[..., None]
75
+
76
+ encoder_outputs.last_hidden_state = encoder_hidden_states
77
+ batch["encoder_outputs"] = encoder_outputs
78
+
79
+ output_audios = eval_model.generate(**batch, **gen_kwargs)
80
+ return output_audios
81
 
82
+ with torch.no_grad():
83
+ generated_audios = generate_step(batch, accelerator)
 
 
 
 
84
 
85
+ # Post-process the generated audio
86
+ audio_arr = generated_audios.cpu().numpy().squeeze()
87
  sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
88
 
89
  return "parler_tts_out.wav"
90
 
91
+ # Gradio interface setup
92
+ default_prompt = "thought no beef im hate to get murder right in these streets i told yall niggins is dead fucking green tbs and tsg my shit only you cant beat out if you aint going to aim and squeeze take your mvp out the game just like a referee im talking about my life you just rapping on beats i be clapping on streets theyre using technology to try to find where the bullets coming from they wont find those z nope because im a smooth criminal i got some screwed loose because im a sick of the"
93
+ default_description = "A male vocalist delivers an energetic and passionate freestyle in a medium-fast tempo, showcasing an enthusiastic and emotional performance with emphatic expression, conveying a youthful and groovy vibe throughout the track."
94
  default_seed = "456"
95
 
96
  interface = gr.Interface(
 
99
  gr.Textbox(label="Prompt", value=default_prompt),
100
  gr.Textbox(label="Description", value=default_description),
101
  gr.Textbox(label="Seed", value=default_seed),
102
+ gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=1.0),
103
  gr.Slider(label="Max Length", minimum=256, maximum=5120, step=256, value=2580),
104
  gr.Dropdown(label="Do Sample", choices=[True, False], value=True)
105
  ],
106
  outputs=gr.Audio(label="Generated Audio"),
107
  title="Parler TTS Audio Generation",
108
+ description="Generate vocals using the Parler TTS model. Provide a prompt, description, and seed to generate the corresponding audio."
109
  )
110
 
111
  if __name__ == "__main__":
112
+ interface.launch()