import spaces
import gradio as gr
# import gradio.helpers
import torch
import os
from glob import glob
from typing import Optional
from PIL import Image
from diffusers.utils import export_to_video
from pipeline import StableVideoDiffusionPipeline
import random
from safetensors import safe_open
from lcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
def get_safetensors_files():
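    """Return the names of all .safetensors checkpoints found in ./safetensors."""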
models_dir = "./safetensors"
safetensors_files = [
f for f in os.listdir(models_dir) if f.endswith(".safetensors")
]
return safetensors_files
def model_select(selected_file):
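    """Swap the UNet weights of the global pipeline for the selected checkpoint.

    The UNet is moved to CPU while the state dict is loaded, then back to CUDA.
    """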
print("load model weights", selected_file)
pipe.unet.cpu()
file_path = os.path.join("./safetensors", selected_file)
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
    pipe.unet.load_state_dict(state_dict, strict=True)  # strict=True raises on any key mismatch
    pipe.unet.cuda()
    del state_dict
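# AnimateLCM distills SVD with a consistency objective, so sampling uses a
# stochastic iterative (consistency-model) scheduler over a short Karras-style
# sigma schedule (sigma_min/sigma_max/rho) rather than the default Euler steps,
# which is what makes 1-8 step inference possible.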
noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=700.0,
sigma_data=1.0,
s_noise=1.0,
rho=7,
clip_denoised=False,
)
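# Base SVD-XT image-to-video weights in fp16; the distilled AnimateLCM
# checkpoint loaded below via model_select() replaces only the UNet.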
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt",
scheduler=noise_scheduler,
torch_dtype=torch.float16,
variant="fp16",
)
pipe.to("cuda")
pipe.enable_model_cpu_offload()  # offload idle submodules to CPU to reduce peak VRAM
model_select("AnimateLCM-SVD-xt-1.1.safetensors")
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # for faster inference
max_64_bit_int = 2**63 - 1
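# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the
# duration of each call to the decorated function.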
@spaces.GPU
def sample(
    image: Image.Image,
seed: Optional[int] = 42,
randomize_seed: bool = False,
motion_bucket_id: int = 80,
fps_id: int = 8,
max_guidance_scale: float = 1.2,
min_guidance_scale: float = 1,
width: int = 1024,
height: int = 576,
num_inference_steps: int = 4,
decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
output_folder: str = "outputs_gradio",
):
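    """Generate a short video from a single input image with AnimateLCM-SVD.

    Returns the path of the exported .mp4 and the seed that was used.
    """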
if image.mode == "RGBA":
image = image.convert("RGB")
if randomize_seed:
seed = random.randint(0, max_64_bit_int)
generator = torch.manual_seed(seed)
os.makedirs(output_folder, exist_ok=True)
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
with torch.autocast("cuda"):
frames = pipe(
image,
decode_chunk_size=decoding_t,
generator=generator,
motion_bucket_id=motion_bucket_id,
height=height,
width=width,
num_inference_steps=num_inference_steps,
min_guidance_scale=min_guidance_scale,
max_guidance_scale=max_guidance_scale,
).frames[0]
export_to_video(frames, video_path, fps=fps_id)
return video_path, seed
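# Gradio UI: image input, generated video output, and sampling controls.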
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="pil")
generate_btn = gr.Button("Generate")
video = gr.Video()
seed = gr.Slider(
label="Seed",
value=42,
randomize=False,
minimum=0,
maximum=max_64_bit_int,
step=1,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
    motion_bucket_id = gr.Slider(
        label="Motion bucket id",
        info="Controls how much motion to add/remove from the image",
        value=80,
        minimum=1,
        maximum=255,
        step=1,
    )
    fps_id = gr.Slider(
        label="Frames per second",
        info="The length of your video in seconds will be 25/fps",
        value=8,
        minimum=5,
        maximum=30,
        step=1,
    )
# note: we want something that is close to 16:9 (1.7777)
# 576 / 320 = 1.8
# 448 / 256 = 1.75
width = gr.Slider(
label="Width of input image",
info="It should be divisible by 64",
value=576, # 256, 320, 384, 448
minimum=256,
maximum=2048,
step=64,
)
    height = gr.Slider(
        label="Height of input image",
        info="It should be divisible by 64",
        value=320,  # 256, 320, 384, 448
        minimum=256,
        maximum=1152,
        step=64,
    )
max_guidance_scale = gr.Slider(
label="Max guidance scale",
info="classifier-free guidance strength",
value=1.2,
minimum=1,
maximum=2,
)
min_guidance_scale = gr.Slider(
label="Min guidance scale",
info="classifier-free guidance strength",
value=1,
minimum=1,
maximum=1.5,
)
num_inference_steps = gr.Slider(
label="Num inference steps",
info="steps for inference",
value=4,
minimum=1,
maximum=20,
step=1,
)
generate_btn.click(
fn=sample,
inputs=[
image,
seed,
randomize_seed,
motion_bucket_id,
fps_id,
max_guidance_scale,
min_guidance_scale,
width,
height,
num_inference_steps,
],
outputs=[video, seed],
api_name="video",
)
if __name__ == "__main__":
demo.queue()
demo.launch(show_error=True)