# PicoAudio / app.py
# ZeyuXie's picture
# Update app.py
# 012fbfa verified
# raw
# history blame
# 4.12 kB
import os
import json
import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
class dotdict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` returns ``d["key"]``, or ``None`` when the key is absent.
    ``d.key = value`` stores ``d["key"] = value``.
    ``del d.key`` removes the item and raises ``KeyError`` if it is missing,
    exactly like ``del d["key"]``.
    """

    def __getattr__(self, name):
        """Return the item stored under ``name``, or ``None`` if absent.

        Raises:
            AttributeError: for dunder names (``__x__``) that normal
                attribute lookup did not find.
        """
        # __getattr__ only runs after regular lookup has failed. Dunder names
        # are probed by protocols (copy, pickle, inspect, ...); answering them
        # with None makes the object appear to provide hooks it does not
        # have, so report those as genuinely missing.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.get(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
class InferRunner:
    """Load the VAE, the PicoDiffusion model and the scheduler for inference.

    Attributes set by ``__init__``:
        vae: ``AutoencoderKL`` built from ``ckpts/ldm/vae_config.json`` with
            weights loaded from ``ckpts/ldm/pytorch_model_vae.bin``.
        pico_model: ``PicoDiffusion`` in eval mode, moved to ``device``.
        scheduler: ``DDPMScheduler`` loaded from the training run's
            ``scheduler_name``.
    """

    def __init__(self, device):
        """Build every inference component on ``device``.

        Args:
            device: Torch device specifier used for ``.to(...)`` and as
                ``map_location`` when loading the VAE weights.
        """
        # Context managers close the checkpoint metadata files promptly
        # instead of leaving the handles to the garbage collector.
        with open("ckpts/ldm/vae_config.json", encoding="utf-8") as config_file:
            vae_config = json.load(config_file)
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
        self.vae.load_state_dict(vae_weights)

        # Only the first record of the JSONL summary holds the training
        # arguments, so read a single line rather than the whole file.
        with open("ckpts/pico_model/summary.jsonl", encoding="utf-8") as summary_file:
            train_args = dotdict(json.loads(summary_file.readline()))

        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
            diffusion_pt="ckpts/pico_model/diffusion.pt",
        ).eval().to(device)
        self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Models are loaded once at import time; infer() reuses this instance.
runner = InferRunner(device)
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    """Synthesize audio for ``caption`` and return the path of the WAV file.

    Args:
        caption: Text prompt forwarded to ``PicoDiffusion.demo_inference``.
        num_steps: Forwarded to ``demo_inference`` as ``num_steps``.
        guidance_scale: Forwarded to ``demo_inference`` as ``guidance_scale``.
        audio_len: Maximum number of samples kept from the decoded waveform.
            The file is written at 16000 Hz, so the default keeps 10 s.

    Returns:
        The path ``"synthesized/output.wav"`` of the written 16-bit PCM file.
    """
    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption,
            runner.scheduler,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            num_samples_per_prompt=1,
            disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        # Keep the first decoded waveform, truncated to audio_len samples.
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]

    outpath = "synthesized/output.wav"
    # sf.write does not create missing directories; make sure the target
    # folder exists so the first request on a fresh checkout does not fail.
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
    # NOTE(review): every call writes the same file, so overlapping requests
    # would overwrite each other's output - confirm whether that matters here.
    sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
    return outpath
# Gradio UI: prompt and sampling controls on the left, generated audio on the
# right. Clicking the button calls infer() with the three input widgets.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## PicoAudio")
    with gr.Row():
        with gr.Column():
            # The label previously opened a single quote around the caption
            # format and never closed it; the closing quote is restored.
            prompt = gr.Textbox(
                label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
                value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
            )
            run_button = gr.Button()
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(label="num_steps", minimum=1,
                                      maximum=300, value=200, step=1)
                guidance_scale = gr.Slider(
                    label="guidance_scale Scale:(Large => more relevant to text but the quality may drop)", minimum=0.1, maximum=8.0, value=3.0, step=0.1
                )
        with gr.Column():
            outaudio = gr.Audio()
    # Input order must match infer(caption, num_steps, guidance_scale).
    run_button.click(fn=infer,
                     inputs=[prompt, num_steps, guidance_scale],
                     outputs=[outaudio])
    # NOTE: the disabled examples block below is stale. It references
    # `ddim_steps`, `scale` and `seed`, none of which are defined in this
    # file, so it must be adapted before it can be re-enabled.
    # with gr.Row():
    #     with gr.Column():
    #         gr.Examples(
    #             examples = [['An amateur recording features a steel drum playing in a higher register',25,5,55],
    #                         ['An instrumental song with a caribbean feel, happy mood, and featuring steel pan music, programmed percussion, and bass',25,5,55],
    #                         ['This musical piece features a playful and emotionally melodic male vocal accompanied by piano',25,5,55],
    #                         ['A eerie yet calming experimental electronic track featuring haunting synthesizer strings and pads',25,5,55],
    #                         ['A slow tempo pop instrumental piece featuring only acoustic guitar with fingerstyle and percussive strumming techniques',25,5,55]],
    #             inputs = [prompt, ddim_steps, scale, seed],
    #             outputs = [outaudio],
    #         )
    #         cache_examples="lazy", # Turn on to cache.
    #     with gr.Column():
    #         pass

# Start the Gradio server for the demo defined above.
demo.launch()