FreeNoise / app.py
Anonymous
update videocrafter2
1706095
raw
history blame
11.7 kB
import gradio as gr
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling_freenoise,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
# Per-resolution settings, keyed by the "Output Size" value handed to infer():
#   width / height - pixel size of the generated video
#   ckpt_dir       - local directory that holds (or will receive) model.ckpt
#   config         - inference YAML whose "model" section describes the network
#   repo_id        - Hugging Face Hub repository that model.ckpt is fetched from
#   fps            - value passed to the model as its "fps" conditioning
#   max_frames     - upper bound applied to the requested frame count, or None
_MODEL_SPECS = {
    "320x512": {
        "width": 512,
        "height": 320,
        "ckpt_dir": "checkpoints/base_512_v2",
        "config": "configs/inference_t2v_tconv512_v2.0_freenoise.yaml",
        "repo_id": "VideoCrafter/VideoCrafter2",
        "fps": 12,
        "max_frames": None,
    },
    "576x1024": {
        "width": 1024,
        "height": 576,
        "ckpt_dir": "checkpoints/base_1024_v1",
        "config": "configs/inference_t2v_1024_v1.0_freenoise.yaml",
        "repo_id": "VideoCrafter/Text2Video-1024",
        "fps": 28,
        "max_frames": 36,
    },
    "256x256": {
        "width": 256,
        "height": 256,
        "ckpt_dir": "checkpoints/base_256_v1",
        "config": "configs/inference_t2v_tconv256_v1.0_freenoise.yaml",
        "repo_id": "VideoCrafter/Text2Video-256",
        "fps": 8,
        "max_frames": None,
    },
}


def _load_model(spec):
    """Build the model described by *spec* on the GPU and load its checkpoint."""
    ckpt_dir = spec["ckpt_dir"]
    ckpt_path = f"{ckpt_dir}/model.ckpt"

    config = OmegaConf.load(spec["config"])
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config).cuda()

    if not os.path.exists(ckpt_path):
        os.makedirs(ckpt_dir, exist_ok=True)
        hf_hub_download(repo_id=spec["repo_id"], filename="model.ckpt", local_dir=ckpt_dir)
    try:
        model = load_model_checkpoint(model, ckpt_path)
    except Exception as err:
        # The file on disk could not be loaded; fetch a fresh copy and retry
        # once. A second failure propagates to the caller.
        print(f"Failed to load {ckpt_path} ({err!r}); re-downloading the checkpoint.")
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename="model.ckpt",
            local_dir=ckpt_dir,
            force_download=True,
        )
        model = load_model_checkpoint(model, ckpt_path)
    model.eval()
    return model


def _reschedule_noise(x_T_total, total_frames, window_size, window_stride):
    """Fill frames past the first window with shuffled copies of earlier noise.

    Works in place along dimension 3 of *x_T_total* and returns the tensor.
    Each group of ``window_stride`` frames starting at ``window_size`` is
    overwritten with a random permutation of the group located ``window_size``
    frames earlier. The final group is clipped so a frame count that is not
    aligned to the stride does not cause a shape mismatch.
    """
    for frame_index in range(window_size, total_frames, window_stride):
        source_index = list(
            range(
                frame_index - window_size,
                frame_index + window_stride - window_size,
            )
        )
        random.shuffle(source_index)
        stop = min(frame_index + window_stride, total_frames)
        x_T_total[:, :, :, frame_index:stop] = x_T_total[
            :, :, :, source_index[: stop - frame_index]
        ]
    return x_T_total


def _write_video(batch_samples, n_samples, save_fps, video_path):
    """Encode the first entry of *batch_samples* as an H.264 file at *video_path*."""
    video = batch_samples[0].detach().cpu()
    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]  # [3, 1*h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
    # Map values from [-1, 1] to [0, 255] and move channels last for the writer.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        video_path,
        grid,
        fps=save_fps,
        video_codec="h264",
        options={"crf": "10"},
    )


def infer(prompt, output_size, seed, num_frames, ddim_steps, unconditional_guidance_scale, save_fps):
    """Generate a video for *prompt* with FreeNoise sampling and return its path.

    Args:
        prompt: Text prompt used as the cross-attention conditioning.
        output_size: One of "320x512", "576x1024" or "256x256"; selects the
            model, its checkpoint and the output resolution.
        seed: Seed passed to ``seed_everything``; a random 16-bit seed is
            drawn when it is ``None``.
        num_frames: Number of frames to generate (capped at 36 for the
            "576x1024" model). A negative value falls back to
            ``model.temporal_length``.
        ddim_steps: Number of DDIM sampling steps.
        unconditional_guidance_scale: Guidance scale forwarded to the sampler.
        save_fps: Frame rate of the written video file.

    Returns:
        The path of the written video, always ``"output.mp4"``.

    Raises:
        ValueError: If *output_size* is not a supported size.
    """
    spec = _MODEL_SPECS.get(output_size)
    if spec is None:
        raise ValueError(
            f"Unsupported output size {output_size!r}; "
            f"expected one of {sorted(_MODEL_SPECS)}"
        )

    window_size = 16
    window_stride = 4

    # These values are used as tensor sizes / step counts, so make sure they
    # are integers even if the caller sends e.g. 32.0.
    num_frames = int(num_frames)
    ddim_steps = int(ddim_steps)
    if spec["max_frames"] is not None:
        num_frames = min(num_frames, spec["max_frames"])

    model = _load_model(spec)

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    seed = int(seed)
    print(f"Using seed: {seed}")
    seed_everything(seed)

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        height=spec["height"],
        width=spec["width"],
        frames=num_frames,
        fps=spec["fps"],
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        window_size=window_size,
        window_stride=window_stride,
    )

    ## latent noise shape
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    x_T_total = torch.randn(
        [args.n_samples, 1, channels, frames, h, w], device=model.device
    ).repeat(1, args.bs, 1, 1, 1, 1)
    x_T_total = _reschedule_noise(
        x_T_total, args.frames, args.window_size, args.window_stride
    )

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [prompt]
    text_emb = model.get_learned_conditioning(prompts)
    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    batch_samples = batch_ddim_sampling_freenoise(
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
        args=args,
        x_T_total=x_T_total,
    )

    video_path = "output.mp4"
    _write_video(batch_samples, args.n_samples, args.savefps, video_path)
    print(video_path)
    return video_path
# Sample prompts offered below the demo; each row carries only the prompt text.
examples = [
    [sample_prompt]
    for sample_prompt in (
        "A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        "A corgi is swimming quickly",
        "A bigfoot walking in the snowstorm",
        "Campfire at night in a snowy forest with starry sky in the background",
        "A panda is surfing in the universe",
    )
]
# Custom stylesheet handed to gr.Blocks(css=...) below.
# "#col-container" matches the elem_id of the main column; the "#share-btn*"
# and ".footer" rules have no matching element in the part of the file shown
# here.
css = """
#col-container {max-width: 640px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
# Build the demo page: header, prompt box, a collapsible parameter panel,
# the generate button, the video output and the example prompts.
# NOTE(review): the nesting below is reconstructed from the context-manager
# structure — confirm it against the deployed layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # The literal is kept flush-left so the rendered Markdown is unchanged.
        gr.Markdown(
            """
<h1 style="text-align: center;">FreeNoise (Longer Text-to-Video)</h1>
<p style="text-align: center;">
FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling (ICLR 2024)
</p>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2310.15169" target="_blank"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="http://haonanqiu.com/projects/FreeNoise.html" target="_blank"><b>[Project Page]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/AILab-CVC/FreeNoise" target="_blank"><b>[Code]</b></a>
</p>
"""
        )

        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="A chihuahua in astronaut suit floating in space, cinematic lighting, glow effect",
        )

        with gr.Row():
            with gr.Accordion(
                "FreeNoise Parameters (feel free to adjust these parameters based on your prompt): ",
                open=False,
            ):
                with gr.Row():
                    output_size = gr.Dropdown(
                        ["320x512", "576x1024", "256x256"],
                        value="320x512",
                        label="Output Size",
                        info="250s for 512 model, 900s for 1024 model (32 frames). Recovering from sleeping will take more time to download ckpt",
                    )
                with gr.Row():
                    num_frames = gr.Slider(
                        label="Frames (a multiple of 4), max 36 for 1024 model",
                        minimum=16,
                        maximum=64,
                        step=4,
                        value=32,
                    )
                    ddim_steps = gr.Slider(
                        label="DDIM Steps",
                        minimum=5,
                        maximum=200,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    unconditional_guidance_scale = gr.Slider(
                        label="Unconditional Guidance Scale",
                        minimum=1.0,
                        maximum=20.0,
                        step=0.1,
                        value=12.0,
                    )
                    save_fps = gr.Slider(
                        label="Save FPS",
                        minimum=1,
                        maximum=30,
                        step=1,
                        value=10,
                    )
                with gr.Row():
                    seed = gr.Slider(
                        label="Random Seed",
                        minimum=0,
                        maximum=10000,
                        step=1,
                        value=123,
                    )

        submit_btn = gr.Button("Generate", variant="primary")
        video_result = gr.Video(label="Video Output")

        # Components in the positional order of infer()'s parameters.
        infer_inputs = [
            prompt_in,
            output_size,
            seed,
            num_frames,
            ddim_steps,
            unconditional_guidance_scale,
            save_fps,
        ]

        gr.Examples(examples=examples, inputs=infer_inputs)

    submit_btn.click(
        fn=infer,
        inputs=infer_inputs,
        outputs=[video_result],
        api_name="zrscp",
    )

# Queue up to 12 pending requests, then start the server.
demo.queue(max_size=12)
demo.launch(show_api=True)