"""Gradio demo for FreeNoise text-to-video generation.

Builds a small web UI around ``infer``, which loads one of three
VideoCrafter checkpoints (selected by output resolution), prepares the
initial latent noise by repeating shuffled copies of earlier frames, runs
DDIM sampling, and writes the result to an MP4 file.
"""
import argparse
import os
import random
import sys

import gradio as gr
import torch
import torchvision
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

# ``funcs`` lives in scripts/evaluation, so that directory must be importable.
sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling_freenoise,
    load_model_checkpoint,
)
from utils.utils import instantiate_from_config

# File name of the checkpoint inside every Hugging Face repo used below.
_CKPT_FILENAME = "model.ckpt"

# Per-resolution settings, keyed by the value of the "Output Size" dropdown.
# ``fps`` is the value passed to the model as conditioning; ``max_frames``
# caps the requested frame count (None means no cap).
_MODEL_SPECS = {
    "320x512": {
        "width": 512,
        "height": 320,
        "ckpt_dir": "checkpoints/base_512_v2",
        "config": "configs/inference_t2v_tconv512_v2.0_freenoise.yaml",
        "repo_id": "VideoCrafter/VideoCrafter2",
        "fps": 12,
        "max_frames": None,
    },
    "576x1024": {
        "width": 1024,
        "height": 576,
        "ckpt_dir": "checkpoints/base_1024_v1",
        "config": "configs/inference_t2v_1024_v1.0_freenoise.yaml",
        "repo_id": "VideoCrafter/Text2Video-1024",
        "fps": 28,
        "max_frames": 36,
    },
    "256x256": {
        "width": 256,
        "height": 256,
        "ckpt_dir": "checkpoints/base_256_v1",
        "config": "configs/inference_t2v_tconv256_v1.0_freenoise.yaml",
        "repo_id": "VideoCrafter/Text2Video-256",
        "fps": 8,
        "max_frames": None,
    },
}


def _load_model(spec):
    """Instantiate the model described by *spec* and load its checkpoint.

    The checkpoint is downloaded from the Hugging Face Hub when it is not
    present locally. If loading fails, the file is re-downloaded with
    ``force_download=True`` and loading is attempted once more; a second
    failure propagates to the caller.

    Args:
        spec: One entry of ``_MODEL_SPECS``.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    config = OmegaConf.load(spec["config"])
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config)
    model = model.cuda()

    ckpt_dir = spec["ckpt_dir"]
    ckpt_path = f"{ckpt_dir}/{_CKPT_FILENAME}"
    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_dir, exist_ok=True)
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename=_CKPT_FILENAME,
            local_dir=ckpt_dir,
        )

    try:
        model = load_model_checkpoint(model, ckpt_path)
    except Exception as err:
        # Presumably a truncated or corrupt download; fetch a fresh copy
        # and retry once instead of failing the request outright.
        print(f"Loading {ckpt_path} failed ({err!r}); re-downloading")
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename=_CKPT_FILENAME,
            local_dir=ckpt_dir,
            force_download=True,
        )
        model = load_model_checkpoint(model, ckpt_path)

    model.eval()
    return model


def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    """Generate a video for *prompt* and return the path of the MP4 file.

    Args:
        prompt: Text prompt used for conditioning.
        output_size: One of the keys of ``_MODEL_SPECS``
            ("320x512", "576x1024" or "256x256").
        seed: Random seed; when ``None`` a 16-bit seed is drawn from
            ``os.urandom``.
        num_frames: Number of frames to generate (capped at 36 for the
            "576x1024" model).
        ddim_steps: Number of DDIM sampling steps.
        unconditional_guidance_scale: Guidance scale forwarded to the sampler.
        save_fps: Frame rate of the written video file.

    Returns:
        The path of the written video, always ``"output.mp4"``.

    Raises:
        ValueError: If *output_size* is not a supported resolution.
    """
    window_size = 16
    window_stride = 4

    spec = _MODEL_SPECS.get(output_size)
    if spec is None:
        raise ValueError(
            f"Unsupported output size {output_size!r}; "
            f"expected one of {sorted(_MODEL_SPECS)}"
        )

    # The values feed range() and tensor shapes, which require integers.
    num_frames = int(num_frames)
    ddim_steps = int(ddim_steps)
    if spec["max_frames"] is not None:
        num_frames = min(num_frames, spec["max_frames"])

    model = _load_model(spec)

    # Seed after the model is built so the noise below depends only on the seed.
    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    seed = int(seed)
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=spec["height"],
        width=spec["width"],
        frames=num_frames,
        fps=spec["fps"],
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
    # Noise rescheduling: every group of `window_stride` frames past the first
    # `window_size` frames is overwritten with a shuffled copy of the group
    # located `window_size` frames earlier (dim 3 is the frame axis).
    for frame_index in range(args.window_size, args.frames, args.window_stride):
        list_index = list(
            range(
                frame_index - args.window_size,
                frame_index + args.window_stride - args.window_size,
            )
        )
        random.shuffle(list_index)
        x_T_total[
            :, :, :, frame_index : frame_index + args.window_stride
        ] = x_T_total[:, :, :, list_index]

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps_cond = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [prompt]

    # Pure inference: no autograd graph is needed for conditioning or sampling.
    with torch.no_grad():
        text_emb = model.get_learned_conditioning(prompts)
        cond = {"c_crossattn": [text_emb], "fps": fps_cond}

        ## inference
        batch_samples = batch_ddim_sampling_freenoise(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
            args=args,
            x_T_total=x_T_total,
        )

    video_path = "output.mp4"
    vid_tensor = batch_samples[0]
    video = vid_tensor.detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w

    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
        for framesheet in video
    ]  # [3, 1*h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
    # Map [-1, 1] to [0, 255] and move channels last for the video writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)

    torchvision.io.write_video(
        video_path,
        grid,
        fps=args.savefps,
        video_codec="h264",
        options={"crf": "10"},
    )

    print(video_path)
    return video_path


examples = [
    ["A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect"],
    ["A corgi is swimming quickly"],
    ["A bigfoot walking in the snowstorm"],
    ["Campfire at night in a snowy forest with starry sky in the background"],
    ["A panda is surfing in the universe"],
]

# Stylesheet for the page. Written as adjacent string literals (one rule per
# literal) so the text handed to Gradio stays exactly as it was.
css = (
    " #col-container {max-width: 640px; margin-left: auto; margin-right: auto;} "
    "a {text-decoration-line: underline; font-weight: 600;} "
    ".animate-spin { animation: spin 1s linear infinite; } "
    "@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } } "
    "#share-btn-container { display: flex; \n"
    "padding-left: 0.5rem !important; padding-right: 0.5rem !important; "
    "background-color: #000000; justify-content: center; align-items: center; "
    "border-radius: 9999px !important; max-width: 15rem; height: 36px; } "
    "div#share-btn-container > div { flex-direction: row; background: black; align-items: center; } "
    "#share-btn-container:hover { background-color: #060606; } "
    "#share-btn { all: initial; color: #ffffff; font-weight: 600; cursor:pointer; "
    "font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; "
    "padding-top: 0.5rem !important; padding-bottom: 0.5rem !important; right:0; } "
    "#share-btn * { all: unset; } "
    "#share-btn-container div:nth-child(-n+2){ width: auto !important; min-height: 0px !important; } "
    "#share-btn-container .wrap { display: none !important; } "
    "#share-btn-container.hidden { display: none!important; } "
    "img[src*='#center'] { display: inline-block; margin: unset; } "
    ".footer { margin-bottom: 45px; margin-top: 10px; text-align: center; "
    "border-bottom: 1px solid #e5e5e5; } "
    ".footer>p { font-size: .8rem; display: inline-block; padding: 0 10px; "
    "transform: translateY(10px); background: white; } "
    ".dark .footer { border-color: #303030; } "
    ".dark .footer>p { background: #0b0f19; } "
)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "\nFreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling (ICLR 2024)\n"
        )

        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        )

        with gr.Row():
            with gr.Accordion(
                'FreeNoise Parameters (feel free to adjust these parameters based on your prompt): ',
                open=False,
            ):
                with gr.Row():
                    output_size = gr.Dropdown(
                        ["320x512", "576x1024", "256x256"],
                        value="320x512",
                        label="Output Size",
                        info="250s for 512 model, 900s for 1024 model (32 frames). Recovering from sleeping will take more time to download ckpt",
                    )
                with gr.Row():
                    num_frames = gr.Slider(
                        label='Frames (a multiple of 4), max 36 for 1024 model',
                        minimum=16,
                        maximum=64,
                        step=4,
                        value=32,
                    )
                    ddim_steps = gr.Slider(
                        label='DDIM Steps',
                        minimum=5,
                        maximum=200,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(
                        label='Unconditional Guidance Scale',
                        minimum=1.0,
                        maximum=20.0,
                        step=0.1,
                        value=12.0,
                    )
                    save_fps = gr.Slider(
                        label='Save FPS',
                        minimum=1,
                        maximum=30,
                        step=1,
                        value=10,
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label='Random Seed',
                        minimum=0,
                        maximum=10000,
                        step=1,
                        value=123,
                    )

        submit_btn = gr.Button("Generate", variant='primary')
        video_result = gr.Video(label="Video Output")

        # Each example row supplies only the prompt; the remaining inputs
        # keep whatever value they currently hold in the UI.
        gr.Examples(
            examples=examples,
            inputs=[
                prompt_in,
                output_size,
                seed,
                num_frames,
                ddim_steps,
                unconditional_guidance_scale,
                save_fps,
            ],
        )

    submit_btn.click(
        fn=infer,
        inputs=[
            prompt_in,
            output_size,
            seed,
            num_frames,
            ddim_steps,
            unconditional_guidance_scale,
            save_fps,
        ],
        outputs=[video_result],
        api_name="zrscp",
    )

demo.queue(max_size=12).launch(show_api=True)