diffusers / app.py
seawolf2357's picture
Update app.py
f028eb9 verified
raw
history blame
4.51 kB
import torch
import gradio as gr
from diffusers import AnimateDiffSparseControlNetPipeline, AutoencoderKL, MotionAdapter, SparseControlNetModel, AnimateDiffPipeline, EulerAncestralDiscreteScheduler
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, load_image
# Pick the compute device once at import time: the GPU when CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def _parse_frame_indices(raw):
    """Turn the frame-index input into a list of integers.

    Args:
        raw: Either a string holding a Python list/tuple literal such as
            ``"[0, 8, 15]"`` (what the UI textbox sends), or an already-built
            list/tuple of integers.

    Returns:
        A new ``list`` of ``int`` frame indices.

    Raises:
        ValueError: If the text is not a valid literal, or the value is not a
            list/tuple made up solely of integers.
    """
    import ast  # function-scope import: only this helper needs it

    parsed = raw
    if isinstance(raw, str):
        # SECURITY: this text comes straight from a public textbox.
        # ast.literal_eval accepts Python literals only, so -- unlike eval() --
        # it cannot run code supplied by a visitor.
        try:
            parsed = ast.literal_eval(raw.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(
                "conditioning_frame_indices must be a list of integers "
                f"such as [0, 8, 15]; got {raw!r}"
            ) from err

    if not isinstance(parsed, (list, tuple)):
        raise ValueError(
            "conditioning_frame_indices must be a list of integers "
            f"such as [0, 8, 15]; got {raw!r}"
        )

    for index in parsed:
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(index, bool) or not isinstance(index, int):
            raise ValueError(
                "conditioning_frame_indices must contain only integers; "
                f"got {index!r} in {raw!r}"
            )

    return list(parsed)


def generate_video(prompt, negative_prompt, num_inference_steps, conditioning_frame_indices, controlnet_conditioning_scale):
    """Generate a scribble-conditioned AnimateDiff clip and save it as a GIF.

    Loads the motion adapter, the sparse ControlNet, the VAE and the
    ``AnimateDiffSparseControlNetPipeline`` on every call, conditions the
    generation on three fixed scribble images fetched by URL, and writes the
    resulting frames to ``output.gif`` in the working directory.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps; converted to ``int``.
        conditioning_frame_indices: Frame indices the conditioning images
            apply to, as a list-literal string (e.g. ``"[0, 8, 15]"``) or a
            list/tuple of integers.
        controlnet_conditioning_scale: ControlNet strength; converted to
            ``float``.

    Returns:
        The path of the written GIF, ``"output.gif"``.

    Raises:
        ValueError: If ``conditioning_frame_indices`` is not a list of
            integers (raised before any model is loaded).
    """
    # Validate and normalise the user-supplied values first so that bad input
    # fails fast, before the expensive model loading below.
    frame_indices = _parse_frame_indices(conditioning_frame_indices)
    conditioning_scale = float(controlnet_conditioning_scale)
    # The UI slider value may arrive as a float; the step count must be an int.
    num_steps = int(num_inference_steps)

    motion_adapter = MotionAdapter.from_pretrained(
        "guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16
    ).to(device)
    controlnet = SparseControlNetModel.from_pretrained(
        "guoyww/animatediff-sparsectrl-scribble", torch_dtype=torch.float16
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
    ).to(device)

    pipe = AnimateDiffSparseControlNetPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V5.1_noVAE",
        motion_adapter=motion_adapter,
        controlnet=controlnet,
        vae=vae,
        torch_dtype=torch.float16,
    ).to(device)
    # Swap in a DPM-Solver++ scheduler derived from the pipeline's own config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        beta_schedule="linear",
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True,
    )

    # Fixed scribble images used as the sparse conditioning frames.
    image_files = [
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png",
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png",
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png",
    ]
    conditioning_frames = [load_image(img_file) for img_file in image_files]

    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_steps,
        conditioning_frames=conditioning_frames,
        controlnet_conditioning_scale=conditioning_scale,
        controlnet_frame_indices=frame_indices,
        # Fixed seed so repeated runs use the same generator state.
        generator=torch.Generator().manual_seed(1337),
    ).frames[0]

    export_to_gif(video, "output.gif")
    return "output.gif"
def generate_simple_video(prompt):
    """Render a text-to-video clip with AnimateDiff and save it as a GIF.

    Builds an ``AnimateDiffPipeline`` with the v1-5-2 motion adapter on every
    call, swaps in an Euler-ancestral scheduler, turns on FreeNoise, VAE
    slicing and model CPU offload, then samples 64 frames for ``prompt``.

    Args:
        prompt: Text prompt passed positionally to the pipeline.

    Returns:
        The path of the written GIF, ``"simple_output.gif"``.
    """
    gif_path = "simple_output.gif"

    motion_module = MotionAdapter.from_pretrained(
        "guoyww/animatediff-motion-adapter-v1-5-2",
        torch_dtype=torch.float16,
    )
    motion_module = motion_module.to(device)

    animator = AnimateDiffPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V6.0_B1_noVAE",
        motion_adapter=motion_module,
        torch_dtype=torch.float16,
    )
    animator = animator.to(device)

    scheduler_settings = {
        "beta_schedule": "linear",
        "beta_start": 0.00085,
        "beta_end": 0.012,
    }
    animator.scheduler = EulerAncestralDiscreteScheduler(**scheduler_settings)

    # Pipeline switches, applied in the same order as before.
    animator.enable_free_noise()
    animator.vae.enable_slicing()
    animator.enable_model_cpu_offload()

    sampling_options = {
        "num_frames": 64,
        "num_inference_steps": 20,
        "guidance_scale": 7.0,
        "decode_chunk_size": 2,
    }
    result = animator(prompt, **sampling_options)
    clip = result.frames[0]

    export_to_gif(clip, gif_path)
    return gif_path
# --- Tab 1: sparse-ControlNet video generation ------------------------------
# Input widgets, in the same order as generate_video's parameters.
_advanced_inputs = [
    gr.Textbox(
        label="Prompt",
        value="an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality",
    ),
    gr.Textbox(
        label="Negative Prompt",
        value="low quality, worst quality, letterboxed",
    ),
    gr.Slider(
        label="Number of Inference Steps",
        minimum=1,
        maximum=100,
        step=1,
        value=25,
    ),
    gr.Textbox(label="Conditioning Frame Indices", value="[0, 8, 15]"),
    gr.Slider(
        label="ControlNet Conditioning Scale",
        minimum=0.1,
        maximum=2.0,
        step=0.1,
        value=1.0,
    ),
]
demo1 = gr.Interface(
    fn=generate_video,
    inputs=_advanced_inputs,
    outputs=gr.Image(label="Generated Video"),
    title="Generate Video with AnimateDiffSparseControlNetPipeline",
    description="Generate a video using the AnimateDiffSparseControlNetPipeline.",
)

# --- Tab 2: prompt-only video generation ------------------------------------
_simple_input = gr.Textbox(
    label="Prompt",
    value="An astronaut riding a horse on Mars.",
)
demo2 = gr.Interface(
    fn=generate_simple_video,
    inputs=_simple_input,
    outputs=gr.Image(label="Generated Simple Video"),
    title="Generate Simple Video with AnimateDiff",
    description="Generate a simple video using the AnimateDiffPipeline.",
)

# Combine both interfaces into one tabbed app and start serving it.
_tab_interfaces = [demo1, demo2]
_tab_names = ["Advanced Video Generation", "Simple Video Generation"]
demo = gr.TabbedInterface(_tab_interfaces, _tab_names)
demo.launch()
#demo.launch(server_name="0.0.0.0", server_port=7910)