import gradio as gr
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSConfig
from transformers import AutoTokenizer
import soundfile as sf
import torch
import os
from accelerate import Accelerator
from accelerate.utils import set_seed

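# install.sh is assumed to install extra runtime dependencies (e.g. a pinned parler_tts build) before the app starts.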
os.system("bash install.sh")

# Setup accelerator
accelerator = Accelerator()
device = accelerator.device
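# Use bf16 on accelerator devices and fp32 on CPU. Note: this local flag only gates the autocast
# call in generate_step(); the Accelerator above was constructed with its default precision settings.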
mixed_precision = "no" if device.type == "cpu" else "bf16"
torch_dtype = torch.float32 if device.type == "cpu" else torch.bfloat16

# Load model and tokenizer
model_path = "AkhilTolani/parler-tts-finetune-vocals-only-large-18720-steps"
config = ParlerTTSConfig.from_pretrained(model_path)
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch_dtype,
    attn_implementation="sdpa"
)
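# accelerator.prepare() moves the model onto the accelerator's device and applies any configured wrappers.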
model = accelerator.prepare(model)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")

def generate_audio(prompt, description, seed, temperature, max_length, do_sample):
    seed = int(seed)
    set_seed(seed)

    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    num_codebooks = model.decoder.config.num_codebooks

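    # The decoder generates one token stream per codebook; requiring at least num_codebooks + 1
    # new tokens ensures the output decodes to at least one audio frame (the +1 presumably
    # covers the decoder start / delay-pattern offset).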
    gen_kwargs = {
        "do_sample": do_sample,
        "temperature": temperature,
        "max_length": max_length,
        "min_new_tokens": num_codebooks + 1,
    }

    # Prepare batch
    batch = {
        "input_ids": input_ids,
        "prompt_input_ids": prompt_input_ids,
    }

    def generate_step(batch, accelerator):
        batch.pop("decoder_attention_mask", None)
        eval_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
        
        # Handle torch.compile if it was used in training
        if hasattr(eval_model, '_orig_mod'):
            eval_model = eval_model._orig_mod

        if mixed_precision != "no":
            with accelerator.autocast():
                output_audios = eval_model.generate(**batch, **gen_kwargs)
        else:
            output_audios = eval_model.generate(**batch, **gen_kwargs)
        
        output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
        return output_audios

    with torch.no_grad():
        generated_audios = generate_step(batch, accelerator)

    # Gather and pad predictions
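    # With a single process this gather is effectively a pass-through; it matters only when running under multiple processes.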
    generated_audios, input_ids, prompts = accelerator.pad_across_processes(
        (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
    )
    generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
        (generated_audios, input_ids, prompts)
    )

    # Convert to CPU and float32
    generated_audios = generated_audios.cpu().float()
    input_ids = input_ids.cpu()
    prompts = prompts.cpu()

    # Post-process the generated audio
    audio_arr = generated_audios[0].numpy().squeeze()  # Take the first sample if multiple were generated
    sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

    return "parler_tts_out.wav"

# Gradio interface setup
default_prompt = "thought no beef im hate to get murder right in these streets i told yall niggins is dead fucking green tbs and tsg my shit only you cant beat out if you aint going to aim and squeeze take your mvp out the game just like a referee im talking about my life you just rapping on beats i be clapping on streets theyre using technology to try to find where the bullets coming from they wont find those z nope because im a smooth criminal i got some screwed loose because im a sick of the"
default_description = "A male vocalist delivers an energetic and passionate freestyle in a medium-fast tempo, showcasing an enthusiastic and emotional performance with emphatic expression, conveying a youthful and groovy vibe throughout the track."
default_seed = "456"

interface = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Textbox(label="Prompt", value=default_prompt), 
        gr.Textbox(label="Description", value=default_description),
        gr.Textbox(label="Seed", value=default_seed),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.75),
        gr.Slider(label="Max Length", minimum=256, maximum=5120, step=256, value=2580),
        gr.Dropdown(label="Do Sample", choices=[True, False], value=True)
    ],
    outputs=gr.Audio(label="Generated Audio"),
    title="Parler TTS Audio Generation",
    description="Generate audio using the Parler TTS model. Provide a prompt, description, and seed to generate the corresponding audio."
)

if __name__ == "__main__":
    interface.launch()