Spaces:
Paused
Paused
import os | |
import cv2 | |
import torch | |
import gradio as gr | |
import torchvision | |
import warnings | |
import numpy as np | |
from PIL import Image, ImageSequence | |
from moviepy.editor import VideoFileClip | |
import imageio | |
from diffusers import ( | |
TextToVideoSDPipeline, | |
AutoencoderKL, | |
DDPMScheduler, | |
DDIMScheduler, | |
UNet3DConditionModel, | |
) | |
from transformers import CLIPTokenizer, CLIPTextModel | |
from diffusers.utils import export_to_video | |
from typing import List | |
from text2vid_modded import TextToVideoSDPipelineModded | |
from invert_utils import ddim_inversion as dd_inversion | |
from gifs_filter import filter | |
import subprocess | |
import spaces | |
def load_frames(image: Image, mode='RGBA'): | |
return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)]) | |
def run_setup(): | |
try: | |
# Step 1: Install Git LFS | |
subprocess.run(["git", "lfs", "install"], check=True) | |
# Step 2: Clone the repository | |
repo_url = "https://huggingface.co/Hmrishav/t2v_sketch-lora" | |
subprocess.run(["git", "clone", repo_url], check=True) | |
# Step 3: Move the checkpoint file | |
source = "t2v_sketch-lora/checkpoint-2500" | |
destination = "./checkpoint-2500/" | |
os.rename(source, destination) | |
print("Setup completed successfully!") | |
except subprocess.CalledProcessError as e: | |
print(f"Error during setup: {e}") | |
except FileNotFoundError as e: | |
print(f"File operation error: {e}") | |
except Exception as e: | |
print(f"Unexpected error: {e}") | |
# Automatically run setup during app initialization | |
run_setup() | |
def save_gif(frames, path): | |
imageio.mimsave( | |
path, | |
[frame.astype(np.uint8) for frame in frames], | |
format="GIF", | |
duration=1 / 10, | |
loop=0 # 0 means infinite loop | |
) | |
def load_image(imgname, target_size=None): | |
pil_img = Image.open(imgname).convert('RGB') | |
if target_size: | |
if isinstance(target_size, int): | |
target_size = (target_size, target_size) | |
pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS) | |
return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0) | |
def prepare_latents(pipe, x_aug): | |
with torch.cuda.amp.autocast(): | |
batch_size, num_frames, channels, height, width = x_aug.shape | |
x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width) | |
latents = pipe.vae.encode(x_aug).latent_dist.sample() | |
latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3]) | |
latents = latents.permute(0, 2, 1, 3, 4) | |
return pipe.vae.config.scaling_factor * latents | |
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16): | |
input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5 | |
input_img = torch.cat(input_img, dim=1) | |
latents = prepare_latents(pipe, input_img).to(torch.bfloat16) | |
inv.set_timesteps(25) | |
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype) | |
return torch.mean(id_latents, dim=2, keepdim=True) | |
def load_primary_models(pretrained_model_path): | |
return ( | |
DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"), | |
CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"), | |
CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"), | |
AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"), | |
UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"), | |
) | |
def initialize_pipeline(model: str, device: str = "cuda"): | |
with warnings.catch_warnings(): | |
warnings.simplefilter("ignore") | |
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model) | |
pipe = TextToVideoSDPipeline.from_pretrained( | |
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b", | |
scheduler=scheduler, | |
tokenizer=tokenizer, | |
text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16), | |
vae=vae.to(device=device, dtype=torch.bfloat16), | |
unet=unet.to(device=device, dtype=torch.bfloat16), | |
) | |
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) | |
return pipe, pipe.scheduler | |
# Initialize the models | |
LORA_CHECKPOINT = "checkpoint-2500" | |
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" | |
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
dtype = torch.bfloat16 | |
pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device) | |
pipe = TextToVideoSDPipelineModded.from_pretrained( | |
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b", | |
scheduler=pipe_inversion.scheduler, | |
tokenizer=pipe_inversion.tokenizer, | |
text_encoder=pipe_inversion.text_encoder, | |
vae=pipe_inversion.vae, | |
unet=pipe_inversion.unet, | |
).to(device) | |
def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_): | |
pipe_inversion.to(device) | |
id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype) | |
latents = id_latents.repeat(num_seeds, 1, 1, 1, 1) | |
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)] | |
video_frames = pipe( | |
prompt=caption, | |
negative_prompt="", | |
num_frames=num_frames, | |
num_inference_steps=25, | |
inv_latents=latents, | |
guidance_scale=9, | |
generator=generator, | |
lambda_=lambda_, | |
).frames | |
gifs = [] | |
for seed in range(num_seeds): | |
vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4" | |
gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif" | |
os.makedirs(os.path.dirname(vid_name), exist_ok=True) | |
os.makedirs(os.path.dirname(gif_name), exist_ok=True) | |
video_path = export_to_video(video_frames[seed], output_video_path=vid_name) | |
VideoFileClip(vid_name).write_gif(gif_name) | |
with Image.open(gif_name) as im: | |
frames = load_frames(im) | |
frames_collect = np.empty((0, 1024, 1024), int) | |
for frame in frames: | |
frame = cv2.resize(frame, (1024, 1024))[:, :, :3] | |
frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY) | |
_, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) | |
frames_collect = np.append(frames_collect, [frame], axis=0) | |
save_gif(frames_collect, gif_name) | |
gifs.append(gif_name) | |
return gifs | |
def generate_output(image, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5) -> List[str]: | |
"""Main function to generate output GIFs""" | |
exp_dir = "static/app_tmp" | |
os.makedirs(exp_dir, exist_ok=True) | |
# Save the input image temporarily | |
temp_image_path = os.path.join(exp_dir, "temp_input.png") | |
image.save(temp_image_path) | |
# Generate the GIFs | |
generated_gifs = process_video( | |
num_frames=10, | |
num_seeds=num_seeds, | |
generator=None, | |
exp_dir=exp_dir, | |
load_name=temp_image_path, | |
caption=prompt, | |
lambda_=1 - lambda_value | |
) | |
# Apply filtering (assuming filter function is imported) | |
filtered_gifs = filter(generated_gifs, temp_image_path) | |
return filtered_gifs | |
def create_gradio_interface(): | |
with gr.Blocks(css=""" | |
.container { | |
max-width: 1200px; | |
margin: 0 auto; | |
padding: 20px; | |
} | |
.example-gallery { | |
margin: 20px 0; | |
padding: 20px; | |
background: #f7f7f7; | |
border-radius: 8px; | |
} | |
.selected-example { | |
margin: 20px 0; | |
padding: 20px; | |
background: #ffffff; | |
border-radius: 8px; | |
} | |
.controls-section { | |
background: #ffffff; | |
padding: 20px; | |
margin: 20px 0; | |
border-radius: 8px; | |
} | |
.output-gallery { | |
min-height: 500px; | |
margin: 20px 0; | |
padding: 20px; | |
background: #f7f7f7; | |
border-radius: 8px; | |
} | |
.example-item { | |
border-radius: 8px; | |
overflow: hidden; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1); | |
transition: transform 0.2s; | |
cursor: pointer; | |
} | |
.example-item:hover { | |
transform: scale(1.05); | |
} | |
/* Prevent gallery images from expanding */ | |
.gallery-image { | |
height: 200px !important; | |
width: 200px !important; | |
object-fit: cover !important; | |
} | |
.generate-btn { | |
width: 100%; | |
margin-top: 1rem; | |
} | |
.generate-btn:disabled { | |
opacity: 0.7; | |
cursor: not-allowed; | |
} | |
""") as demo: | |
gr.Markdown( | |
""" | |
<div align="center" id = "user-content-toc"> | |
<img align="left" width="70" height="70" src="https://github.com/user-attachments/assets/c61cec76-3c4b-42eb-8c65-f07e0166b7d8" alt=""> | |
# [FlipSketch: Flipping assets Drawings to Text-Guided Sketch Animations](https://hmrishavbandy.github.io/flipsketch-web/) | |
## [Hmrishav Bandyopadhyay](https://hmrishavbandy.github.io/) . [Yi-Zhe Song](https://personalpages.surrey.ac.uk/y.song/) | |
</div> | |
""" | |
) | |
with gr.Tabs() as tabs: | |
# First tab: Examples (Secure) | |
with gr.Tab("Examples"): | |
gr.Markdown("## Step 1 👉 Select a sketch from the gallery of sketches") | |
examples_dir = "static/examples" | |
if os.path.exists(examples_dir): | |
example_images = [] | |
for example in os.listdir(examples_dir): | |
if example.endswith(('.png', '.jpg', '.jpeg')): | |
example_path = os.path.join(examples_dir, example) | |
example_images.append(Image.open(example_path)) | |
example_selection = gr.Gallery( | |
example_images, | |
label="Sketch Gallery", | |
elem_classes="example-gallery", | |
columns=4, | |
rows=2, | |
height="auto", | |
allow_preview=False, # Disable preview expansion | |
show_share_button=False, | |
interactive=False, | |
selected_index=None # Don't pre-select any image | |
) | |
gr.Markdown("## Step 2 👉 Describe the motion you want to generate") | |
with gr.Group(elem_classes="selected-example"): | |
with gr.Row(): | |
selected_example = gr.Image( | |
type="pil", | |
label="Selected Sketch", | |
scale=1, | |
interactive=False, | |
show_download_button=False, | |
height=300 # Fixed height for consistency | |
) | |
with gr.Column(scale=2): | |
example_prompt = gr.Textbox( | |
label="Prompt", | |
placeholder="Describe the motion...", | |
lines=3 | |
) | |
with gr.Row(): | |
example_num_seeds = gr.Slider( | |
minimum=1, | |
maximum=10, | |
value=5, | |
step=1, | |
label="Seeds" | |
) | |
example_lambda = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.5, | |
step=0.1, | |
label="Motion Strength" | |
) | |
example_generate_btn = gr.Button( | |
"Generate Animation", | |
variant="primary", | |
elem_classes="generate-btn", | |
interactive=True, | |
) | |
gr.Markdown("## Result 👉 Generated Animations ❤️") | |
example_gallery = gr.Gallery( | |
label="Results", | |
elem_classes="output-gallery", | |
columns=3, | |
rows=2, | |
height="auto", | |
allow_preview=False, # Disable preview expansion | |
show_share_button=False, | |
object_fit="cover", | |
preview=False | |
) | |
# Second tab: Upload | |
with gr.Tab("Upload Your Sketch"): | |
with gr.Group(elem_classes="selected-example"): | |
with gr.Row(): | |
upload_image = gr.Image( | |
type="pil", | |
label="Upload Your Sketch", | |
scale=1, | |
height=300, # Fixed height for consistency | |
show_download_button=False, | |
sources=["upload"], | |
) | |
with gr.Column(scale=2): | |
upload_prompt = gr.Textbox( | |
label="Prompt", | |
placeholder="Describe what you want to generate...", | |
lines=3 | |
) | |
with gr.Row(): | |
upload_num_seeds = gr.Slider( | |
minimum=1, | |
maximum=10, | |
value=5, | |
step=1, | |
label="Number of Variations" | |
) | |
upload_lambda = gr.Slider( | |
minimum=0, | |
maximum=1, | |
value=0.5, | |
step=0.1, | |
label="Motion Strength" | |
) | |
upload_generate_btn = gr.Button( | |
"Generate Animation", | |
variant="primary", | |
elem_classes="generate-btn", | |
size="lg", | |
interactive=True, | |
) | |
gr.Markdown("## Result 👉 Generated Animations ❤️") | |
upload_gallery = gr.Gallery( | |
label="Results", | |
elem_classes="output-gallery", | |
columns=3, | |
rows=2, | |
height="auto", | |
allow_preview=False, # Disable preview expansion | |
show_share_button=False, | |
object_fit="cover", | |
preview=False | |
) | |
# Event handlers | |
def select_example(evt: gr.SelectData): | |
prompts = {'sketch1.png': 'The camel walks slowly', | |
'sketch2.png': 'The wine in the wine glass sways from side to side', | |
'sketch3.png': 'The squirrel is eating a nut', | |
'sketch4.png': 'The surfer surfs on the waves', | |
'sketch5.png': 'A galloping horse', | |
'sketch6.png': 'The cat walks forward', | |
'sketch7.png': 'The eagle flies in the sky', | |
'sketch8.png': 'The flower is blooming slowly', | |
'sketch9.png': 'The reindeer looks around', | |
'sketch10.png': 'The cloud floats in the sky', | |
'sketch11.png': 'The jazz saxophonist performs on stage with a rhythmic sway, his upper body sways subtly to the rhythm of the music.', | |
'sketch12.png': 'The biker rides on the road',} | |
if evt.index < len(example_images): | |
example_img = example_images[evt.index] | |
prompt_text = prompts.get(os.path.basename(example_img.filename), "") | |
return [ | |
example_img, | |
prompt_text | |
] | |
return [None, ""] | |
example_selection.select( | |
select_example, | |
None, | |
[selected_example, example_prompt] | |
) | |
example_generate_btn.click( | |
fn=generate_output, | |
inputs=[ | |
selected_example, | |
example_prompt, | |
example_num_seeds, | |
example_lambda | |
], | |
outputs=example_gallery | |
) | |
upload_generate_btn.click( | |
fn=generate_output, | |
inputs=[ | |
upload_image, | |
upload_prompt, | |
upload_num_seeds, | |
upload_lambda | |
], | |
outputs=upload_gallery | |
) | |
return demo | |
# Launch the app | |
if __name__ == "__main__": | |
demo = create_gradio_interface() | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
show_api=False | |
) |