Spaces:
Paused
Paused
resolve deps
Browse files- .gitignore +7 -0
- README.md +1 -1
- app_gradio.py +471 -0
- requirements.txt +1 -0
- static/app_tmp/temp_input.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
checkpoint-2500/
|
2 |
+
t2v_sketch-lora/
|
3 |
+
__pycache__/
|
4 |
+
static/app_tmp/gif_logs/*
|
5 |
+
static/app_tmp/mp4_logs/*
|
6 |
+
static/app_tmp/png_logs/*
|
7 |
+
static/uploads/*
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🚀
|
|
4 |
colorFrom: blue
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
app_file:
|
8 |
pinned: false
|
9 |
---
|
10 |
|
|
|
4 |
colorFrom: blue
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
app_file: app_gradio.py
|
8 |
pinned: false
|
9 |
---
|
10 |
|
app_gradio.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
import torchvision
|
6 |
+
import warnings
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image, ImageSequence
|
9 |
+
from moviepy.editor import VideoFileClip
|
10 |
+
import imageio
|
11 |
+
from diffusers import (
|
12 |
+
TextToVideoSDPipeline,
|
13 |
+
AutoencoderKL,
|
14 |
+
DDPMScheduler,
|
15 |
+
DDIMScheduler,
|
16 |
+
UNet3DConditionModel,
|
17 |
+
)
|
18 |
+
from transformers import CLIPTokenizer, CLIPTextModel
|
19 |
+
from diffusers.utils import export_to_video
|
20 |
+
from typing import List
|
21 |
+
from text2vid_modded import TextToVideoSDPipelineModded
|
22 |
+
from invert_utils import ddim_inversion as dd_inversion
|
23 |
+
from gifs_filter import filter
|
24 |
+
import subprocess
|
25 |
+
import spaces
|
26 |
+
|
27 |
+
|
28 |
+
def load_frames(image: Image, mode='RGBA'):
|
29 |
+
return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
|
30 |
+
|
31 |
+
|
32 |
+
def run_setup():
|
33 |
+
try:
|
34 |
+
# Step 1: Install Git LFS
|
35 |
+
subprocess.run(["git", "lfs", "install"], check=True)
|
36 |
+
|
37 |
+
# Step 2: Clone the repository
|
38 |
+
repo_url = "https://huggingface.co/Hmrishav/t2v_sketch-lora"
|
39 |
+
subprocess.run(["git", "clone", repo_url], check=True)
|
40 |
+
|
41 |
+
# Step 3: Move the checkpoint file
|
42 |
+
source = "t2v_sketch-lora/checkpoint-2500"
|
43 |
+
destination = "./checkpoint-2500/"
|
44 |
+
os.rename(source, destination)
|
45 |
+
|
46 |
+
print("Setup completed successfully!")
|
47 |
+
except subprocess.CalledProcessError as e:
|
48 |
+
print(f"Error during setup: {e}")
|
49 |
+
except FileNotFoundError as e:
|
50 |
+
print(f"File operation error: {e}")
|
51 |
+
except Exception as e:
|
52 |
+
print(f"Unexpected error: {e}")
|
53 |
+
|
54 |
+
# Automatically run setup during app initialization
|
55 |
+
run_setup()
|
56 |
+
|
57 |
+
|
58 |
+
def save_gif(frames, path):
|
59 |
+
imageio.mimsave(
|
60 |
+
path,
|
61 |
+
[frame.astype(np.uint8) for frame in frames],
|
62 |
+
format="GIF",
|
63 |
+
duration=1 / 10,
|
64 |
+
loop=0 # 0 means infinite loop
|
65 |
+
)
|
66 |
+
|
67 |
+
def load_image(imgname, target_size=None):
|
68 |
+
pil_img = Image.open(imgname).convert('RGB')
|
69 |
+
if target_size:
|
70 |
+
if isinstance(target_size, int):
|
71 |
+
target_size = (target_size, target_size)
|
72 |
+
pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
|
73 |
+
return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)
|
74 |
+
|
75 |
+
def prepare_latents(pipe, x_aug):
|
76 |
+
with torch.cuda.amp.autocast():
|
77 |
+
batch_size, num_frames, channels, height, width = x_aug.shape
|
78 |
+
x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
|
79 |
+
latents = pipe.vae.encode(x_aug).latent_dist.sample()
|
80 |
+
latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
|
81 |
+
latents = latents.permute(0, 2, 1, 3, 4)
|
82 |
+
return pipe.vae.config.scaling_factor * latents
|
83 |
+
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
|
87 |
+
input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
|
88 |
+
input_img = torch.cat(input_img, dim=1)
|
89 |
+
latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
|
90 |
+
inv.set_timesteps(25)
|
91 |
+
id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
|
92 |
+
return torch.mean(id_latents, dim=2, keepdim=True)
|
93 |
+
|
94 |
+
def load_primary_models(pretrained_model_path):
|
95 |
+
return (
|
96 |
+
DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
|
97 |
+
CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
|
98 |
+
CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
|
99 |
+
AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
|
100 |
+
UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
|
101 |
+
)
|
102 |
+
|
103 |
+
def initialize_pipeline(model: str, device: str = "cuda"):
|
104 |
+
with warnings.catch_warnings():
|
105 |
+
warnings.simplefilter("ignore")
|
106 |
+
scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
|
107 |
+
pipe = TextToVideoSDPipeline.from_pretrained(
|
108 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
109 |
+
scheduler=scheduler,
|
110 |
+
tokenizer=tokenizer,
|
111 |
+
text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
|
112 |
+
vae=vae.to(device=device, dtype=torch.bfloat16),
|
113 |
+
unet=unet.to(device=device, dtype=torch.bfloat16),
|
114 |
+
)
|
115 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
116 |
+
return pipe, pipe.scheduler
|
117 |
+
|
118 |
+
# Initialize the models
|
119 |
+
LORA_CHECKPOINT = "checkpoint-2500"
|
120 |
+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
|
121 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
122 |
+
dtype = torch.bfloat16
|
123 |
+
|
124 |
+
pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
|
125 |
+
pipe = TextToVideoSDPipelineModded.from_pretrained(
|
126 |
+
pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
|
127 |
+
scheduler=pipe_inversion.scheduler,
|
128 |
+
tokenizer=pipe_inversion.tokenizer,
|
129 |
+
text_encoder=pipe_inversion.text_encoder,
|
130 |
+
vae=pipe_inversion.vae,
|
131 |
+
unet=pipe_inversion.unet,
|
132 |
+
).to(device)
|
133 |
+
|
134 |
+
@spaces.GPU(duration=100)
|
135 |
+
@torch.no_grad()
|
136 |
+
def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
|
137 |
+
pipe_inversion.to(device)
|
138 |
+
id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
|
139 |
+
latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
|
140 |
+
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
|
141 |
+
video_frames = pipe(
|
142 |
+
prompt=caption,
|
143 |
+
negative_prompt="",
|
144 |
+
num_frames=num_frames,
|
145 |
+
num_inference_steps=25,
|
146 |
+
inv_latents=latents,
|
147 |
+
guidance_scale=9,
|
148 |
+
generator=generator,
|
149 |
+
lambda_=lambda_,
|
150 |
+
).frames
|
151 |
+
|
152 |
+
gifs = []
|
153 |
+
for seed in range(num_seeds):
|
154 |
+
vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4"
|
155 |
+
gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif"
|
156 |
+
|
157 |
+
os.makedirs(os.path.dirname(vid_name), exist_ok=True)
|
158 |
+
os.makedirs(os.path.dirname(gif_name), exist_ok=True)
|
159 |
+
|
160 |
+
video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
|
161 |
+
VideoFileClip(vid_name).write_gif(gif_name)
|
162 |
+
|
163 |
+
with Image.open(gif_name) as im:
|
164 |
+
frames = load_frames(im)
|
165 |
+
|
166 |
+
frames_collect = np.empty((0, 1024, 1024), int)
|
167 |
+
for frame in frames:
|
168 |
+
frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
|
169 |
+
frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
|
170 |
+
_, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
171 |
+
frames_collect = np.append(frames_collect, [frame], axis=0)
|
172 |
+
|
173 |
+
save_gif(frames_collect, gif_name)
|
174 |
+
gifs.append(gif_name)
|
175 |
+
|
176 |
+
return gifs
|
177 |
+
|
178 |
+
def generate_output(image, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5) -> List[str]:
|
179 |
+
"""Main function to generate output GIFs"""
|
180 |
+
exp_dir = "static/app_tmp"
|
181 |
+
os.makedirs(exp_dir, exist_ok=True)
|
182 |
+
|
183 |
+
# Save the input image temporarily
|
184 |
+
temp_image_path = os.path.join(exp_dir, "temp_input.png")
|
185 |
+
image.save(temp_image_path)
|
186 |
+
|
187 |
+
# Generate the GIFs
|
188 |
+
generated_gifs = process_video(
|
189 |
+
num_frames=10,
|
190 |
+
num_seeds=num_seeds,
|
191 |
+
generator=None,
|
192 |
+
exp_dir=exp_dir,
|
193 |
+
load_name=temp_image_path,
|
194 |
+
caption=prompt,
|
195 |
+
lambda_=1 - lambda_value
|
196 |
+
)
|
197 |
+
|
198 |
+
# Apply filtering (assuming filter function is imported)
|
199 |
+
filtered_gifs = filter(generated_gifs, temp_image_path)
|
200 |
+
|
201 |
+
return filtered_gifs
|
202 |
+
|
203 |
+
|
204 |
+
def create_gradio_interface():
|
205 |
+
with gr.Blocks(css="""
|
206 |
+
.container {
|
207 |
+
max-width: 1200px;
|
208 |
+
margin: 0 auto;
|
209 |
+
padding: 20px;
|
210 |
+
}
|
211 |
+
.example-gallery {
|
212 |
+
margin: 20px 0;
|
213 |
+
padding: 20px;
|
214 |
+
background: #f7f7f7;
|
215 |
+
border-radius: 8px;
|
216 |
+
}
|
217 |
+
.selected-example {
|
218 |
+
margin: 20px 0;
|
219 |
+
padding: 20px;
|
220 |
+
background: #ffffff;
|
221 |
+
border-radius: 8px;
|
222 |
+
|
223 |
+
}
|
224 |
+
.controls-section {
|
225 |
+
background: #ffffff;
|
226 |
+
padding: 20px;
|
227 |
+
margin: 20px 0;
|
228 |
+
border-radius: 8px;
|
229 |
+
|
230 |
+
}
|
231 |
+
.output-gallery {
|
232 |
+
min-height: 500px;
|
233 |
+
margin: 20px 0;
|
234 |
+
padding: 20px;
|
235 |
+
background: #f7f7f7;
|
236 |
+
border-radius: 8px;
|
237 |
+
}
|
238 |
+
.example-item {
|
239 |
+
border-radius: 8px;
|
240 |
+
overflow: hidden;
|
241 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
242 |
+
transition: transform 0.2s;
|
243 |
+
cursor: pointer;
|
244 |
+
}
|
245 |
+
.example-item:hover {
|
246 |
+
transform: scale(1.05);
|
247 |
+
}
|
248 |
+
/* Prevent gallery images from expanding */
|
249 |
+
.gallery-image {
|
250 |
+
height: 200px !important;
|
251 |
+
width: 200px !important;
|
252 |
+
object-fit: cover !important;
|
253 |
+
}
|
254 |
+
.generate-btn {
|
255 |
+
width: 100%;
|
256 |
+
margin-top: 1rem;
|
257 |
+
}
|
258 |
+
|
259 |
+
.generate-btn:disabled {
|
260 |
+
opacity: 0.7;
|
261 |
+
cursor: not-allowed;
|
262 |
+
}
|
263 |
+
""") as demo:
|
264 |
+
gr.Markdown(
|
265 |
+
"""
|
266 |
+
|
267 |
+
<div align="center" id = "user-content-toc">
|
268 |
+
<img align="left" width="70" height="70" src="https://github.com/user-attachments/assets/c61cec76-3c4b-42eb-8c65-f07e0166b7d8" alt="">
|
269 |
+
|
270 |
+
# [FlipSketch: Flipping assets Drawings to Text-Guided Sketch Animations](https://hmrishavbandy.github.io/flipsketch-web/)
|
271 |
+
## [Hmrishav Bandyopadhyay](https://hmrishavbandy.github.io/) . [Yi-Zhe Song](https://personalpages.surrey.ac.uk/y.song/)
|
272 |
+
</div>
|
273 |
+
|
274 |
+
"""
|
275 |
+
)
|
276 |
+
|
277 |
+
with gr.Tabs() as tabs:
|
278 |
+
# First tab: Examples (Secure)
|
279 |
+
with gr.Tab("Examples"):
|
280 |
+
gr.Markdown("## Step 1 👉 Select a sketch from the gallery of sketches")
|
281 |
+
examples_dir = "static/examples"
|
282 |
+
if os.path.exists(examples_dir):
|
283 |
+
example_images = []
|
284 |
+
for example in os.listdir(examples_dir):
|
285 |
+
if example.endswith(('.png', '.jpg', '.jpeg')):
|
286 |
+
example_path = os.path.join(examples_dir, example)
|
287 |
+
example_images.append(Image.open(example_path))
|
288 |
+
|
289 |
+
example_selection = gr.Gallery(
|
290 |
+
example_images,
|
291 |
+
label="Sketch Gallery",
|
292 |
+
elem_classes="example-gallery",
|
293 |
+
columns=4,
|
294 |
+
rows=2,
|
295 |
+
height="auto",
|
296 |
+
allow_preview=False, # Disable preview expansion
|
297 |
+
show_share_button=False,
|
298 |
+
interactive=False,
|
299 |
+
selected_index=None # Don't pre-select any image
|
300 |
+
)
|
301 |
+
gr.Markdown("## Step 2 👉 Describe the motion you want to generate")
|
302 |
+
with gr.Group(elem_classes="selected-example"):
|
303 |
+
with gr.Row():
|
304 |
+
selected_example = gr.Image(
|
305 |
+
type="pil",
|
306 |
+
label="Selected Sketch",
|
307 |
+
scale=1,
|
308 |
+
interactive=False,
|
309 |
+
show_download_button=False,
|
310 |
+
height=300 # Fixed height for consistency
|
311 |
+
)
|
312 |
+
with gr.Column(scale=2):
|
313 |
+
example_prompt = gr.Textbox(
|
314 |
+
label="Prompt",
|
315 |
+
placeholder="Describe the motion...",
|
316 |
+
lines=3
|
317 |
+
)
|
318 |
+
with gr.Row():
|
319 |
+
example_num_seeds = gr.Slider(
|
320 |
+
minimum=1,
|
321 |
+
maximum=10,
|
322 |
+
value=5,
|
323 |
+
step=1,
|
324 |
+
label="Seeds"
|
325 |
+
)
|
326 |
+
example_lambda = gr.Slider(
|
327 |
+
minimum=0,
|
328 |
+
maximum=1,
|
329 |
+
value=0.5,
|
330 |
+
step=0.1,
|
331 |
+
label="Motion Strength"
|
332 |
+
)
|
333 |
+
example_generate_btn = gr.Button(
|
334 |
+
"Generate Animation",
|
335 |
+
variant="primary",
|
336 |
+
elem_classes="generate-btn",
|
337 |
+
interactive=True,
|
338 |
+
)
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
gr.Markdown("## Result 👉 Generated Animations ❤️")
|
343 |
+
example_gallery = gr.Gallery(
|
344 |
+
label="Results",
|
345 |
+
elem_classes="output-gallery",
|
346 |
+
columns=3,
|
347 |
+
rows=2,
|
348 |
+
height="auto",
|
349 |
+
allow_preview=False, # Disable preview expansion
|
350 |
+
show_share_button=False,
|
351 |
+
object_fit="cover",
|
352 |
+
preview=False
|
353 |
+
)
|
354 |
+
|
355 |
+
# Second tab: Upload
|
356 |
+
with gr.Tab("Upload Your Sketch"):
|
357 |
+
with gr.Group(elem_classes="selected-example"):
|
358 |
+
with gr.Row():
|
359 |
+
upload_image = gr.Image(
|
360 |
+
type="pil",
|
361 |
+
label="Upload Your Sketch",
|
362 |
+
scale=1,
|
363 |
+
height=300, # Fixed height for consistency
|
364 |
+
show_download_button=False,
|
365 |
+
sources=["upload"],
|
366 |
+
)
|
367 |
+
with gr.Column(scale=2):
|
368 |
+
upload_prompt = gr.Textbox(
|
369 |
+
label="Prompt",
|
370 |
+
placeholder="Describe what you want to generate...",
|
371 |
+
lines=3
|
372 |
+
)
|
373 |
+
with gr.Row():
|
374 |
+
upload_num_seeds = gr.Slider(
|
375 |
+
minimum=1,
|
376 |
+
maximum=10,
|
377 |
+
value=5,
|
378 |
+
step=1,
|
379 |
+
label="Number of Variations"
|
380 |
+
)
|
381 |
+
upload_lambda = gr.Slider(
|
382 |
+
minimum=0,
|
383 |
+
maximum=1,
|
384 |
+
value=0.5,
|
385 |
+
step=0.1,
|
386 |
+
label="Motion Strength"
|
387 |
+
)
|
388 |
+
upload_generate_btn = gr.Button(
|
389 |
+
"Generate Animation",
|
390 |
+
variant="primary",
|
391 |
+
elem_classes="generate-btn",
|
392 |
+
size="lg",
|
393 |
+
interactive=True,
|
394 |
+
)
|
395 |
+
|
396 |
+
gr.Markdown("## Result 👉 Generated Animations ❤️")
|
397 |
+
upload_gallery = gr.Gallery(
|
398 |
+
label="Results",
|
399 |
+
elem_classes="output-gallery",
|
400 |
+
columns=3,
|
401 |
+
rows=2,
|
402 |
+
height="auto",
|
403 |
+
allow_preview=False, # Disable preview expansion
|
404 |
+
show_share_button=False,
|
405 |
+
object_fit="cover",
|
406 |
+
preview=False
|
407 |
+
)
|
408 |
+
|
409 |
+
# Event handlers
|
410 |
+
def select_example(evt: gr.SelectData):
|
411 |
+
prompts = {'sketch1.png': 'The camel walks slowly',
|
412 |
+
'sketch2.png': 'The wine in the wine glass sways from side to side',
|
413 |
+
'sketch3.png': 'The squirrel is eating a nut',
|
414 |
+
'sketch4.png': 'The surfer surfs on the waves',
|
415 |
+
'sketch5.png': 'A galloping horse',
|
416 |
+
'sketch6.png': 'The cat walks forward',
|
417 |
+
'sketch7.png': 'The eagle flies in the sky',
|
418 |
+
'sketch8.png': 'The flower is blooming slowly',
|
419 |
+
'sketch9.png': 'The reindeer looks around',
|
420 |
+
'sketch10.png': 'The cloud floats in the sky',
|
421 |
+
'sketch11.png': 'The jazz saxophonist performs on stage with a rhythmic sway, his upper body sways subtly to the rhythm of the music.',
|
422 |
+
'sketch12.png': 'The biker rides on the road',}
|
423 |
+
if evt.index < len(example_images):
|
424 |
+
example_img = example_images[evt.index]
|
425 |
+
prompt_text = prompts.get(os.path.basename(example_img.filename), "")
|
426 |
+
|
427 |
+
|
428 |
+
return [
|
429 |
+
example_img,
|
430 |
+
prompt_text
|
431 |
+
]
|
432 |
+
return [None, ""]
|
433 |
+
|
434 |
+
example_selection.select(
|
435 |
+
select_example,
|
436 |
+
None,
|
437 |
+
[selected_example, example_prompt]
|
438 |
+
)
|
439 |
+
|
440 |
+
example_generate_btn.click(
|
441 |
+
fn=generate_output,
|
442 |
+
inputs=[
|
443 |
+
selected_example,
|
444 |
+
example_prompt,
|
445 |
+
example_num_seeds,
|
446 |
+
example_lambda
|
447 |
+
],
|
448 |
+
outputs=example_gallery
|
449 |
+
)
|
450 |
+
|
451 |
+
upload_generate_btn.click(
|
452 |
+
fn=generate_output,
|
453 |
+
inputs=[
|
454 |
+
upload_image,
|
455 |
+
upload_prompt,
|
456 |
+
upload_num_seeds,
|
457 |
+
upload_lambda
|
458 |
+
],
|
459 |
+
outputs=upload_gallery
|
460 |
+
)
|
461 |
+
|
462 |
+
return demo
|
463 |
+
|
464 |
+
# Launch the app
|
465 |
+
if __name__ == "__main__":
|
466 |
+
demo = create_gradio_interface()
|
467 |
+
demo.launch(
|
468 |
+
server_name="0.0.0.0",
|
469 |
+
server_port=7860,
|
470 |
+
show_api=False
|
471 |
+
)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
gunicorn
|
|
|
2 |
accelerate==0.29.2
|
3 |
blinker==1.9.0
|
4 |
certifi==2024.8.30
|
|
|
1 |
gunicorn
|
2 |
+
spaces
|
3 |
accelerate==0.29.2
|
4 |
blinker==1.9.0
|
5 |
certifi==2024.8.30
|
static/app_tmp/temp_input.png
ADDED