Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,485 Bytes
b9d6819 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import os
import torch
import random
import numpy as np
import gradio as gr
import soundfile as sf
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DDIMScheduler
from src.models.conditioners import MaskDiT
from src.modules.autoencoder_wrapper import Autoencoder
from src.inference import inference
from src.utils import load_yaml_with_includes
# Load model and configs
def load_models(config_name, ckpt_path, vae_path, device):
    """Load the config and every model component the demo needs.

    Args:
        config_name: Path to the YAML config, parsed with
            ``load_yaml_with_includes``.
        ckpt_path: Path to the diffusion checkpoint; its ``'model'`` entry is
            loaded as the ``MaskDiT`` state dict.
        vae_path: Path to the autoencoder (codec) checkpoint.
        device: Target device (e.g. ``'cuda'`` or ``'cpu'``) for all modules.

    Returns:
        ``(autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params)``.
        ``autoencoder``, ``unet`` and ``text_encoder`` are put in eval mode;
        ``params`` is the parsed config.
    """
    params = load_yaml_with_includes(config_name)

    # Codec (autoencoder) model.
    ae_cfg = params['autoencoder']
    autoencoder = Autoencoder(ckpt_path=vae_path,
                              model_type=ae_cfg['name'],
                              quantization_first=ae_cfg['q_first']).to(device)
    autoencoder.eval()

    # T5 text encoder; tokenizer and encoder use the same pretrained identifier.
    t5_name = params['text_encoder']['model']
    tokenizer = T5Tokenizer.from_pretrained(t5_name)
    text_encoder = T5EncoderModel.from_pretrained(t5_name).to(device)
    text_encoder.eval()

    # Main diffusion backbone. It is a MaskDiT, not a U-Net; the `unet` name is
    # kept because callers unpack the return value under that name.
    unet = MaskDiT(**params['model']).to(device)
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors load
    # on CPU-only hosts (the script falls back to device='cpu' when CUDA is
    # unavailable). load_state_dict then copies the weights into the module,
    # which already lives on `device`.
    state = torch.load(ckpt_path, map_location='cpu')
    unet.load_state_dict(state['model'])
    unet.eval()

    # DDIM noise scheduler built from the 'diff' section of the config.
    noise_scheduler = DDIMScheduler(**params['diff'])
    return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params
# Upper bound for user-selectable and randomly drawn seeds (largest int32 value).
MAX_SEED = np.iinfo(np.int32).max
# Model and config paths (relative paths, resolved against the working directory).
config_name = 'ckpts/ezaudio-xl.yml'
ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
vae_path = 'ckpts/vae/1m.pt'
# Output directory, created up front.
# NOTE(review): nothing in this file currently writes to save_path.
save_path = 'output/'
os.makedirs(save_path, exist_ok=True)
# Use the GPU when CUDA is available, otherwise run on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load every component once at import time; generate_audio reads these globals.
autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
                                                                                   device)
# Warm-up: a single throwaway add_noise call on dummy tensors; the result is
# discarded. NOTE(review): presumably done so scheduler state is initialized on
# `device` before the first request; the (1, 128, 128) shape looks arbitrary —
# confirm both.
latents = torch.randn((1, 128, 128), device=device)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
_ = noise_scheduler.add_noise(latents, noise, timesteps)
# Inference function
def generate_audio(text, length,
                   guidance_scale, guidance_rescale, ddim_steps, eta,
                   random_seed, randomize_seed):
    """Generate audio for a text prompt using the globally loaded models.

    Args:
        text: Text prompt describing the sound.
        length: Requested duration in seconds; multiplied by the config's
            ``latent_sr`` to get the number of latent frames.
        guidance_scale: Guidance scale forwarded to ``inference``.
        guidance_rescale: Guidance rescale factor forwarded to ``inference``.
        ddim_steps: Number of DDIM sampling steps forwarded to ``inference``.
        eta: DDIM eta forwarded to ``inference``.
        random_seed: Seed forwarded to ``inference``; replaced when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        ``(sample_rate, waveform)``: the config's ``sr`` and a NumPy array,
        the format consumed by ``gr.Audio(type="numpy")``.
    """
    neg_text = None  # no negative prompt in this demo

    # Slider values may arrive as floats (e.g. 10.0). Frame counts, step counts
    # and seeds feed tensor shapes / RNG seeding, which require ints.
    length = int(round(length * params['autoencoder']['latent_sr']))
    ddim_steps = int(ddim_steps)
    if randomize_seed:
        random_seed = random.randint(0, MAX_SEED)
    random_seed = int(random_seed)

    pred = inference(autoencoder, unet, None, None,
                     tokenizer, text_encoder,
                     params, noise_scheduler,
                     text, neg_text,
                     length,
                     guidance_scale, guidance_rescale,
                     ddim_steps, eta, random_seed,
                     device)
    # Both squeeze(0) calls require two leading singleton dims.
    # NOTE(review): presumably pred is (1, 1, num_samples) — confirm.
    pred = pred.cpu().numpy().squeeze(0).squeeze(0)
    return params['autoencoder']['sr'], pred
# Gradio Interface
def gradio_interface():
    """Build the text-to-audio Gradio UI around ``generate_audio`` and launch it.

    The order of ``inputs`` must match the positional parameters of
    ``generate_audio``.
    """
    # Primary inputs.
    text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking")
    length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
    # Sampling / guidance settings.
    guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
    guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
    ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
    eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
    # Seed controls; the slider previously had no label and rendered unnamed.
    random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
    # Output: generate_audio returns a (sample_rate, ndarray) tuple.
    output_audio = gr.Audio(label="Converted Audio", type="numpy")

    gr.Interface(
        fn=generate_audio,
        inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
                random_seed_input, randomize_seed],
        outputs=output_audio,
        title="EzAudio Text-to-Audio Generator",
        description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
        allow_flagging="never"
    ).launch()
if __name__ == "__main__":
    # Build the Gradio UI and start the server only when run as a script.
    gradio_interface()
|