import os
import gc
import shutil

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

# Download the fine-tuned reverse-motion checkpoint from the Hub
snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (
    AttentionStore,
    register_temporal_self_attention_control,
    register_temporal_self_attention_flip_control,
)
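
# Base Stable Video Diffusion (img2vid-xt) checkpoint used by the pipeline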
pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    scheduler=noise_scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
)
ref_unet = pipe.ori_unet
state_dict = pipe.unet.state_dict()

# Compute delta-w: the weight difference between the fine-tuned
# (reverse-motion) UNet and the original SVD UNet
finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    checkpoint_dir,
    subfolder="unet",
    torch_dtype=torch.float16,
)
assert finetuned_unet.config.num_frames == 14
ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    variant="fp16",
    torch_dtype=torch.float16,
)
finetuned_state_dict = finetuned_unet.state_dict()
ori_state_dict = ori_unet.state_dict()

# Merge only the deltas for the temporal self-attention value and
# output projections onto the pipeline's UNet weights
for name, param in finetuned_state_dict.items():
    if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
        delta_w = param - ori_state_dict[name]
        state_dict[name] = state_dict[name] + delta_w
pipe.unet.load_state_dict(state_dict)

# Register attention controllers: the reference UNet stores its temporal
# self-attention so the main UNet can reuse it with flipped (reversed) ordering
controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)

controller = AttentionStore()
register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
device = "cuda"
pipe = pipe.to(device)
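
# Clear any leftover files from a previous run before writing new results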
def check_outputs_folder(folder_path):
    # Check if the folder exists
    if os.path.exists(folder_path) and os.path.isdir(folder_path):
        # Delete all contents inside the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        print(f'The folder {folder_path} does not exist.')
# Custom CUDA memory management function
def cuda_memory_cleanup():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
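
# ZeroGPU allocates a GPU to this function for up to 90 seconds per call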
@spaces.GPU(duration=90)
def infer(frame1_path, frame2_path, progress=gr.Progress(track_tqdm=True)):
    # Fixed inference settings for this demo
    seed = 42
    num_inference_steps = 10
    noise_injection_steps = 0
    noise_injection_ratio = 0.5
    weighted_average = False

    generator = torch.Generator(device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    frame1 = load_image(frame1_path)
    frame1 = frame1.resize((512, 288))

    frame2 = load_image(frame2_path)
    frame2 = frame2.resize((512, 288))

    cuda_memory_cleanup()
    frames = pipe(
        image1=frame1, image2=frame2,
        num_inference_steps=num_inference_steps,  # 50
        generator=generator,
        weighted_average=weighted_average,  # True
        noise_injection_steps=noise_injection_steps,  # 0
        noise_injection_ratio=noise_injection_ratio,  # 0.5
        decode_chunk_size=18,
    ).frames[0]
    print(f"Generated {len(frames)} frames")

    out_dir = "result"
    check_outputs_folder(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "video_result.mp4")

    if out_path.endswith('.gif'):
        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, out_path, fps=7)

    return out_path
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
        gr.Markdown("## Generative Inbetweening: Adapting Image-to-Video Models for Keyframe Interpolation")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://svd-keyframe-interpolation.github.io/'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
            <a href='https://arxiv.org/abs/2408.15239'>
                <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(label="FRAME 1", type="filepath")
                image_input2 = gr.Image(label="FRAME 2", type="filepath")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output = gr.Video(label="Interpolated result")
        gr.Examples(
            examples=[
                ["examples/example_001/frame1.png", "examples/example_001/frame2.png"],
                ["examples/example_002/frame1.png", "examples/example_002/frame2.png"],
                ["examples/example_003/frame1.png", "examples/example_003/frame2.png"],
                ["examples/example_004/frame1.png", "examples/example_004/frame2.png"]
            ],
            inputs=[image_input1, image_input2]
        )

    submit_btn.click(
        fn=infer,
        inputs=[image_input1, image_input2],
        outputs=[output],
        show_api=False
    )

demo.queue().launch(show_api=False, show_error=True)
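
# To run this Space locally (a sketch; assumes the custom_diffusers/ and
# attn_ctrl/ packages from the fffiloni/svd_keyframe_interpolation repo
# sit next to this file):
#   pip install torch diffusers gradio spaces huggingface_hub
#   python app.py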