# PicoAudio / inference.py
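"""PicoAudio inference: generate audio from a free-text or timestamped caption.

Example invocation (a sketch; the checkpoint paths below are placeholders for
your local copies of the PicoAudio diffusion and LAION-CLAP checkpoints):

    python inference.py -t "spraying two times then gunshot three times." \
        -exp ./ckpts/pico_model \
        --freeze_text_encoder_ckpt ./ckpts/laion_clap/630k-audioset-best.pt
"""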
import os
import json
import random
import argparse
import soundfile as sf
import numpy as np
import torch
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt


class dotdict(dict):
    """Dot-notation access to dictionary attributes."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
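# dotdict gives the training config loaded from summary.jsonl attribute-style
# access, e.g. train_args.snr_gamma instead of train_args["snr_gamma"].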


def parse_args():
    parser = argparse.ArgumentParser(description="Inference for the text-to-audio generation task.")
    parser.add_argument(
        "--text", "-t", type=str, default="spraying two times then gunshot three times.",
        help="Free-text caption."
    )
    parser.add_argument(
        "--timestamp_caption", "-c", type=str,
        default=None,
        # default="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",
        help="Timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'."
    )
    parser.add_argument(
        "--exp_path", "-exp", type=str,
        default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
        help="Path to the experiment directory containing summary.jsonl and diffusion.pt."
    )
    parser.add_argument(
        "--freeze_text_encoder_ckpt", type=str,
        default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt",
        help="Path to the frozen LAION-CLAP text-encoder checkpoint."
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Random seed."
    )
    args = parser.parse_args()
    args.original_args = os.path.join(args.exp_path, "summary.jsonl")
    args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
    return args


def main():
    args = parse_args()
    # The first line of summary.jsonl holds the training configuration.
    with open(args.original_args) as f:
        train_args = dotdict(json.loads(f.readline()))

    # Seed all RNGs for reproducible sampling.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Step 1: if no timestamp caption is given, derive one from the free text via an LLM.
    if args.timestamp_caption is None:
        # args.timestamp_caption = preprocess_gpt(args.text)
        args.timestamp_caption = preprocess_gemini(args.text)
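    # The resulting caption should follow the format of the commented-out default
    # above, e.g. "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."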

    # Step 2: load the pretrained VAE/STFT and the PicoDiffusion model (GPU required).
    print("------Load model")
    name = "audioldm-s-full"
    vae, stft = build_pretrained_models(name)
    vae, stft = vae.cuda(), stft.cuda()
    model = PicoDiffusion(
        scheduler_name=train_args.scheduler_name,
        unet_model_config_path=train_args.unet_model_config,
        snr_gamma=train_args.snr_gamma,
        freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
        diffusion_pt=args.diffusion_pt,
    ).cuda().eval()
    scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")

    # Step 3: run diffusion sampling, decode latents to a mel spectrogram, then to a waveform.
    num_steps, guidance, num_samples, audio_len = 200, 3.0, 1, 16000 * 10  # 10 s at 16 kHz
    output_dir = os.path.join(
        "/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized",
        f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}",
    )
    os.makedirs(output_dir, exist_ok=True)
    print("------Diffusion begin!")
    with torch.no_grad():
        latents = model.demo_inference(args.timestamp_caption, scheduler, num_steps, guidance, num_samples, disable_progress=True)
        mel = vae.decode_first_stage(latents)
        wave = vae.decode_to_waveform(mel)
    out_path = f"{output_dir}/{args.timestamp_caption}.wav"
    sf.write(out_path, wave[0][:audio_len], samplerate=16000, subtype="PCM_16")
    print(f"------Wrote audio to {out_path}")


if __name__ == "__main__":
    main()