PicoAudio / app.py
ZeyuXie's picture
Update app.py
cc04ad5 verified
raw
history blame
6.59 kB
import json
import os
import tempfile

import numpy as np
import torch
import soundfile as sf
import gradio as gr
from diffusers import DDPMScheduler

from pico_model import PicoDiffusion
from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
class dotdict(dict):
    """A ``dict`` whose keys can also be used as attributes.

    ``d.key`` is ``d.get(key)``, so a missing key yields ``None`` rather than
    raising. ``d.key = v`` stores ``d[key] = v``. ``del d.key`` removes the key
    and raises ``KeyError`` if it is absent.
    """

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails, i.e. for dict keys.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
class InferRunner:
    """Load and hold the PicoAudio inference components.

    Attributes (set in ``__init__``):
        vae: AudioLDM ``AutoencoderKL`` loaded from ``<ckpt_dir>/ldm``, in eval mode.
        pico_model: ``PicoDiffusion`` model, in eval mode on ``device``.
        scheduler: ``DDPMScheduler`` loaded from the ``scheduler`` subfolder of
            the training run's ``scheduler_name``.
    """

    def __init__(self, device, ckpt_dir="ckpts"):
        """Build all models from the checkpoint directory.

        Args:
            device: Torch device to load onto, e.g. ``"cuda"`` or ``"cpu"``.
            ckpt_dir: Root checkpoint directory containing the ``ldm/``,
                ``pico_model/`` and ``laion_clap/`` sub-folders. Relative paths
                resolve against the current working directory. Defaults to
                ``"ckpts"``.
        """
        ldm_dir = os.path.join(ckpt_dir, "ldm")
        pico_dir = os.path.join(ckpt_dir, "pico_model")

        # VAE: config + weights. Use context managers so file handles are closed.
        with open(os.path.join(ldm_dir, "vae_config.json"), encoding="utf-8") as f:
            vae_config = json.load(f)
        self.vae = AutoencoderKL(**vae_config).to(device)
        vae_weights = torch.load(
            os.path.join(ldm_dir, "pytorch_model_vae.bin"), map_location=device
        )
        self.vae.load_state_dict(vae_weights)
        # Inference only: switch the VAE to eval mode (as done for pico_model
        # below) so any dropout / normalization layers behave deterministically.
        self.vae.eval()

        # Only the first line of summary.jsonl is used: it holds the training
        # arguments. Read just that line instead of the whole file.
        with open(os.path.join(pico_dir, "summary.jsonl"), encoding="utf-8") as f:
            train_args = dotdict(json.loads(f.readline()))

        self.pico_model = PicoDiffusion(
            scheduler_name=train_args.scheduler_name,
            unet_model_config_path=train_args.unet_model_config,
            snr_gamma=train_args.snr_gamma,
            freeze_text_encoder_ckpt=os.path.join(
                ckpt_dir, "laion_clap", "630k-audioset-best.pt"
            ),
            diffusion_pt=os.path.join(pico_dir, "diffusion.pt"),
        ).eval().to(device)
        # NOTE(review): scheduler_name is presumably a diffusers Hub id or local
        # path with a "scheduler" subfolder — confirm against summary.jsonl.
        self.scheduler = DDPMScheduler.from_pretrained(
            train_args.scheduler_name, subfolder="scheduler"
        )
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Models are built once at import time and shared by every inference call.
runner = InferRunner(device)
# Supported sound-event names, listed in the UI description.
event_list = get_event()
def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
    """Generate audio for a timestamp caption and save it as a WAV file.

    Args:
        caption: Timestamp caption, e.g.
            ``"dog_barking at 0.562-2.562_4.25-6.25."``.
        num_steps: Number of diffusion sampling steps. Cast to ``int`` because
            UI sliders may deliver a float.
        guidance_scale: Guidance scale passed to ``demo_inference``.
        audio_len: Maximum number of samples kept from the decoded waveform.
            The default, 16000 * 10, is 10 s at the 16 kHz rate used for writing.

    Returns:
        Path to a newly created 16-bit PCM, 16 kHz WAV file.
    """
    with torch.no_grad():
        latents = runner.pico_model.demo_inference(
            caption,
            runner.scheduler,
            num_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            num_samples_per_prompt=1,
            disable_progress=True,
        )
        mel = runner.vae.decode_first_stage(latents)
        # Keep the first (only) sample and trim it to audio_len samples.
        wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]

    # Write to a unique file per request. A fixed "output.wav" is shared by all
    # concurrent sessions, so one request could overwrite, or be served,
    # another request's audio.
    fd, outpath = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
    return outpath
def preprocess(caption):
    """Turn a free-text caption into a timestamp caption using Gemini.

    The result is returned twice, as a 2-tuple, so one call can fill two
    output components with the same text.
    """
    timestamp_caption = preprocess_gemini(caption)
    return (timestamp_caption,) * 2
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## PicoAudio")
    with gr.Row():
        # Derive the count from the list so the text stays accurate if the
        # supported event set changes.
        description_text = f"Support {len(event_list)} events: {', '.join(event_list)}"
        gr.Markdown(description_text)

    # ---- Step 1: free-text caption -> timestamp caption (LLM preprocessing) ----
    with gr.Row():
        gr.Markdown("## Step1")
    with gr.Row():
        preprocess_description_text = (
            "Preprocess: transfer free-text into timestamp caption via LLM. "
            "This demo uses Gemini as the preprocessor. If any errors occur, please try a few more times. "
            "We also provide the GPT version consistent with the paper in the file 'Files/llm_preprocess.py'. "
            "You can use your own api_key to modify and run 'Files/inference.py' for local inference."
        )
        gr.Markdown(preprocess_description_text)
    with gr.Row():
        with gr.Column():
            freetext_prompt = gr.Textbox(
                label="Free-text prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
                value="a dog barks three times.",
            )
            preprocess_run_button = gr.Button()
        with gr.Column():
            freetext_prompt_out = gr.Textbox(label="Timestamp Caption: Preprocess output")
    with gr.Row():
        with gr.Column():
            # No ``fn`` is given, so clicking an example only fills the input
            # box; ``outputs`` is therefore not needed.
            gr.Examples(
                examples=[
                    ["spraying two times then gunshot three times."],
                    ["a dog barks three times."],
                    ["cow mooing two times."],
                ],
                inputs=[freetext_prompt],
            )
        with gr.Column():
            pass

    # ---- Step 2: timestamp caption -> audio ----
    with gr.Row():
        gr.Markdown("## Step2")
    with gr.Row():
        generate_description_text = "Generate audio based on timestamp caption."
        gr.Markdown(generate_description_text)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Timestamp Caption: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
                value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
            )
            generate_run_button = gr.Button()
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
                guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
        with gr.Column():
            outaudio = gr.Audio()

    # Step 1 fills both the Step 2 prompt and the Step 1 output box.
    preprocess_run_button.click(fn=preprocess, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
    generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])

    with gr.Row():
        with gr.Column():
            # Each example row supplies only a caption, so it maps to the
            # prompt box alone; the sliders keep their current values.
            gr.Examples(
                examples=[
                    ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
                    ["dog_barking at 0.562-2.562_4.25-6.25."],
                    ["cow_mooing at 0.958-3.582_5.272-7.896."],
                ],
                inputs=[prompt],
            )
        with gr.Column():
            pass

demo.launch()
# Alternative single-page gr.Interface version of this demo; commented out and not executed.
# description_text = f"18 events: {', '.join(event_list)}"
# prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
# value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
# outaudio = gr.Audio()
# num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
# guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
# gr_interface = gr.Interface(
# fn=infer,
# inputs=[prompt, num_steps, guidance_scale],
# outputs=[outaudio],
# title="PicoAudio",
# description=description_text,
# allow_flagging=False,
# examples=[
# ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
# ["dog_barking at 0.562-2.562_4.25-6.25."],
# ["cow_mooing at 0.958-3.582_5.272-7.896."],
# ],
# cache_examples="lazy", # Turn on to cache.
# )
# gr_interface.queue(10).launch()