Hmrishav commited on
Commit
e79d24a
1 Parent(s): ca282ad

resolve deps

Browse files
Files changed (5) hide show
  1. .gitignore +7 -0
  2. README.md +1 -1
  3. app_gradio.py +471 -0
  4. requirements.txt +1 -0
  5. static/app_tmp/temp_input.png +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ checkpoint-2500/
2
+ t2v_sketch-lora/
3
+ __pycache__/
4
+ static/app_tmp/gif_logs/*
5
+ static/app_tmp/mp4_logs/*
6
+ static/app_tmp/png_logs/*
7
+ static/uploads/*
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🚀
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
- app_file: app.py
8
  pinned: false
9
  ---
10
 
 
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
7
+ app_file: app_gradio.py
8
  pinned: false
9
  ---
10
 
app_gradio.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import gradio as gr
5
+ import torchvision
6
+ import warnings
7
+ import numpy as np
8
+ from PIL import Image, ImageSequence
9
+ from moviepy.editor import VideoFileClip
10
+ import imageio
11
+ from diffusers import (
12
+ TextToVideoSDPipeline,
13
+ AutoencoderKL,
14
+ DDPMScheduler,
15
+ DDIMScheduler,
16
+ UNet3DConditionModel,
17
+ )
18
+ from transformers import CLIPTokenizer, CLIPTextModel
19
+ from diffusers.utils import export_to_video
20
+ from typing import List
21
+ from text2vid_modded import TextToVideoSDPipelineModded
22
+ from invert_utils import ddim_inversion as dd_inversion
23
+ from gifs_filter import filter
24
+ import subprocess
25
+ import spaces
26
+
27
+
28
+ def load_frames(image: Image, mode='RGBA'):
29
+ return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])
30
+
31
+
32
+ def run_setup():
33
+ try:
34
+ # Step 1: Install Git LFS
35
+ subprocess.run(["git", "lfs", "install"], check=True)
36
+
37
+ # Step 2: Clone the repository
38
+ repo_url = "https://huggingface.co/Hmrishav/t2v_sketch-lora"
39
+ subprocess.run(["git", "clone", repo_url], check=True)
40
+
41
+ # Step 3: Move the checkpoint file
42
+ source = "t2v_sketch-lora/checkpoint-2500"
43
+ destination = "./checkpoint-2500/"
44
+ os.rename(source, destination)
45
+
46
+ print("Setup completed successfully!")
47
+ except subprocess.CalledProcessError as e:
48
+ print(f"Error during setup: {e}")
49
+ except FileNotFoundError as e:
50
+ print(f"File operation error: {e}")
51
+ except Exception as e:
52
+ print(f"Unexpected error: {e}")
53
+
54
+ # Automatically run setup during app initialization
55
+ run_setup()
56
+
57
+
58
+ def save_gif(frames, path):
59
+ imageio.mimsave(
60
+ path,
61
+ [frame.astype(np.uint8) for frame in frames],
62
+ format="GIF",
63
+ duration=1 / 10,
64
+ loop=0 # 0 means infinite loop
65
+ )
66
+
67
+ def load_image(imgname, target_size=None):
68
+ pil_img = Image.open(imgname).convert('RGB')
69
+ if target_size:
70
+ if isinstance(target_size, int):
71
+ target_size = (target_size, target_size)
72
+ pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
73
+ return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)
74
+
75
+ def prepare_latents(pipe, x_aug):
76
+ with torch.cuda.amp.autocast():
77
+ batch_size, num_frames, channels, height, width = x_aug.shape
78
+ x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
79
+ latents = pipe.vae.encode(x_aug).latent_dist.sample()
80
+ latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
81
+ latents = latents.permute(0, 2, 1, 3, 4)
82
+ return pipe.vae.config.scaling_factor * latents
83
+
84
+
85
+ @torch.no_grad()
86
+ def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
87
+ input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
88
+ input_img = torch.cat(input_img, dim=1)
89
+ latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
90
+ inv.set_timesteps(25)
91
+ id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
92
+ return torch.mean(id_latents, dim=2, keepdim=True)
93
+
94
+ def load_primary_models(pretrained_model_path):
95
+ return (
96
+ DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
97
+ CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
98
+ CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
99
+ AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
100
+ UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
101
+ )
102
+
103
+ def initialize_pipeline(model: str, device: str = "cuda"):
104
+ with warnings.catch_warnings():
105
+ warnings.simplefilter("ignore")
106
+ scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
107
+ pipe = TextToVideoSDPipeline.from_pretrained(
108
+ pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
109
+ scheduler=scheduler,
110
+ tokenizer=tokenizer,
111
+ text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
112
+ vae=vae.to(device=device, dtype=torch.bfloat16),
113
+ unet=unet.to(device=device, dtype=torch.bfloat16),
114
+ )
115
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
116
+ return pipe, pipe.scheduler
117
+
118
+ # Initialize the models
119
+ LORA_CHECKPOINT = "checkpoint-2500"
120
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
121
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
122
+ dtype = torch.bfloat16
123
+
124
+ pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
125
+ pipe = TextToVideoSDPipelineModded.from_pretrained(
126
+ pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
127
+ scheduler=pipe_inversion.scheduler,
128
+ tokenizer=pipe_inversion.tokenizer,
129
+ text_encoder=pipe_inversion.text_encoder,
130
+ vae=pipe_inversion.vae,
131
+ unet=pipe_inversion.unet,
132
+ ).to(device)
133
+
134
+ @spaces.GPU(duration=100)
135
+ @torch.no_grad()
136
+ def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
137
+ pipe_inversion.to(device)
138
+ id_latents = invert(pipe_inversion, inv, load_name).to(device, dtype=dtype)
139
+ latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
140
+ generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(num_seeds)]
141
+ video_frames = pipe(
142
+ prompt=caption,
143
+ negative_prompt="",
144
+ num_frames=num_frames,
145
+ num_inference_steps=25,
146
+ inv_latents=latents,
147
+ guidance_scale=9,
148
+ generator=generator,
149
+ lambda_=lambda_,
150
+ ).frames
151
+
152
+ gifs = []
153
+ for seed in range(num_seeds):
154
+ vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4"
155
+ gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif"
156
+
157
+ os.makedirs(os.path.dirname(vid_name), exist_ok=True)
158
+ os.makedirs(os.path.dirname(gif_name), exist_ok=True)
159
+
160
+ video_path = export_to_video(video_frames[seed], output_video_path=vid_name)
161
+ VideoFileClip(vid_name).write_gif(gif_name)
162
+
163
+ with Image.open(gif_name) as im:
164
+ frames = load_frames(im)
165
+
166
+ frames_collect = np.empty((0, 1024, 1024), int)
167
+ for frame in frames:
168
+ frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
169
+ frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
170
+ _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
171
+ frames_collect = np.append(frames_collect, [frame], axis=0)
172
+
173
+ save_gif(frames_collect, gif_name)
174
+ gifs.append(gif_name)
175
+
176
+ return gifs
177
+
178
+ def generate_output(image, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5) -> List[str]:
179
+ """Main function to generate output GIFs"""
180
+ exp_dir = "static/app_tmp"
181
+ os.makedirs(exp_dir, exist_ok=True)
182
+
183
+ # Save the input image temporarily
184
+ temp_image_path = os.path.join(exp_dir, "temp_input.png")
185
+ image.save(temp_image_path)
186
+
187
+ # Generate the GIFs
188
+ generated_gifs = process_video(
189
+ num_frames=10,
190
+ num_seeds=num_seeds,
191
+ generator=None,
192
+ exp_dir=exp_dir,
193
+ load_name=temp_image_path,
194
+ caption=prompt,
195
+ lambda_=1 - lambda_value
196
+ )
197
+
198
+ # Apply filtering (assuming filter function is imported)
199
+ filtered_gifs = filter(generated_gifs, temp_image_path)
200
+
201
+ return filtered_gifs
202
+
203
+
204
+ def create_gradio_interface():
205
+ with gr.Blocks(css="""
206
+ .container {
207
+ max-width: 1200px;
208
+ margin: 0 auto;
209
+ padding: 20px;
210
+ }
211
+ .example-gallery {
212
+ margin: 20px 0;
213
+ padding: 20px;
214
+ background: #f7f7f7;
215
+ border-radius: 8px;
216
+ }
217
+ .selected-example {
218
+ margin: 20px 0;
219
+ padding: 20px;
220
+ background: #ffffff;
221
+ border-radius: 8px;
222
+
223
+ }
224
+ .controls-section {
225
+ background: #ffffff;
226
+ padding: 20px;
227
+ margin: 20px 0;
228
+ border-radius: 8px;
229
+
230
+ }
231
+ .output-gallery {
232
+ min-height: 500px;
233
+ margin: 20px 0;
234
+ padding: 20px;
235
+ background: #f7f7f7;
236
+ border-radius: 8px;
237
+ }
238
+ .example-item {
239
+ border-radius: 8px;
240
+ overflow: hidden;
241
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
242
+ transition: transform 0.2s;
243
+ cursor: pointer;
244
+ }
245
+ .example-item:hover {
246
+ transform: scale(1.05);
247
+ }
248
+ /* Prevent gallery images from expanding */
249
+ .gallery-image {
250
+ height: 200px !important;
251
+ width: 200px !important;
252
+ object-fit: cover !important;
253
+ }
254
+ .generate-btn {
255
+ width: 100%;
256
+ margin-top: 1rem;
257
+ }
258
+
259
+ .generate-btn:disabled {
260
+ opacity: 0.7;
261
+ cursor: not-allowed;
262
+ }
263
+ """) as demo:
264
+ gr.Markdown(
265
+ """
266
+
267
+ <div align="center" id = "user-content-toc">
268
+ <img align="left" width="70" height="70" src="https://github.com/user-attachments/assets/c61cec76-3c4b-42eb-8c65-f07e0166b7d8" alt="">
269
+
270
+ # [FlipSketch: Flipping assets Drawings to Text-Guided Sketch Animations](https://hmrishavbandy.github.io/flipsketch-web/)
271
+ ## [Hmrishav Bandyopadhyay](https://hmrishavbandy.github.io/) . [Yi-Zhe Song](https://personalpages.surrey.ac.uk/y.song/)
272
+ </div>
273
+
274
+ """
275
+ )
276
+
277
+ with gr.Tabs() as tabs:
278
+ # First tab: Examples (Secure)
279
+ with gr.Tab("Examples"):
280
+ gr.Markdown("## Step 1 👉 &nbsp; &nbsp; &nbsp; Select a sketch from the gallery of sketches")
281
+ examples_dir = "static/examples"
282
+ if os.path.exists(examples_dir):
283
+ example_images = []
284
+ for example in os.listdir(examples_dir):
285
+ if example.endswith(('.png', '.jpg', '.jpeg')):
286
+ example_path = os.path.join(examples_dir, example)
287
+ example_images.append(Image.open(example_path))
288
+
289
+ example_selection = gr.Gallery(
290
+ example_images,
291
+ label="Sketch Gallery",
292
+ elem_classes="example-gallery",
293
+ columns=4,
294
+ rows=2,
295
+ height="auto",
296
+ allow_preview=False, # Disable preview expansion
297
+ show_share_button=False,
298
+ interactive=False,
299
+ selected_index=None # Don't pre-select any image
300
+ )
301
+ gr.Markdown("## Step 2 👉 &nbsp; &nbsp; &nbsp; Describe the motion you want to generate")
302
+ with gr.Group(elem_classes="selected-example"):
303
+ with gr.Row():
304
+ selected_example = gr.Image(
305
+ type="pil",
306
+ label="Selected Sketch",
307
+ scale=1,
308
+ interactive=False,
309
+ show_download_button=False,
310
+ height=300 # Fixed height for consistency
311
+ )
312
+ with gr.Column(scale=2):
313
+ example_prompt = gr.Textbox(
314
+ label="Prompt",
315
+ placeholder="Describe the motion...",
316
+ lines=3
317
+ )
318
+ with gr.Row():
319
+ example_num_seeds = gr.Slider(
320
+ minimum=1,
321
+ maximum=10,
322
+ value=5,
323
+ step=1,
324
+ label="Seeds"
325
+ )
326
+ example_lambda = gr.Slider(
327
+ minimum=0,
328
+ maximum=1,
329
+ value=0.5,
330
+ step=0.1,
331
+ label="Motion Strength"
332
+ )
333
+ example_generate_btn = gr.Button(
334
+ "Generate Animation",
335
+ variant="primary",
336
+ elem_classes="generate-btn",
337
+ interactive=True,
338
+ )
339
+
340
+
341
+
342
+ gr.Markdown("## Result 👉 &nbsp; &nbsp; &nbsp; Generated Animations ❤️")
343
+ example_gallery = gr.Gallery(
344
+ label="Results",
345
+ elem_classes="output-gallery",
346
+ columns=3,
347
+ rows=2,
348
+ height="auto",
349
+ allow_preview=False, # Disable preview expansion
350
+ show_share_button=False,
351
+ object_fit="cover",
352
+ preview=False
353
+ )
354
+
355
+ # Second tab: Upload
356
+ with gr.Tab("Upload Your Sketch"):
357
+ with gr.Group(elem_classes="selected-example"):
358
+ with gr.Row():
359
+ upload_image = gr.Image(
360
+ type="pil",
361
+ label="Upload Your Sketch",
362
+ scale=1,
363
+ height=300, # Fixed height for consistency
364
+ show_download_button=False,
365
+ sources=["upload"],
366
+ )
367
+ with gr.Column(scale=2):
368
+ upload_prompt = gr.Textbox(
369
+ label="Prompt",
370
+ placeholder="Describe what you want to generate...",
371
+ lines=3
372
+ )
373
+ with gr.Row():
374
+ upload_num_seeds = gr.Slider(
375
+ minimum=1,
376
+ maximum=10,
377
+ value=5,
378
+ step=1,
379
+ label="Number of Variations"
380
+ )
381
+ upload_lambda = gr.Slider(
382
+ minimum=0,
383
+ maximum=1,
384
+ value=0.5,
385
+ step=0.1,
386
+ label="Motion Strength"
387
+ )
388
+ upload_generate_btn = gr.Button(
389
+ "Generate Animation",
390
+ variant="primary",
391
+ elem_classes="generate-btn",
392
+ size="lg",
393
+ interactive=True,
394
+ )
395
+
396
+ gr.Markdown("## Result 👉 &nbsp; &nbsp; &nbsp; Generated Animations ❤️")
397
+ upload_gallery = gr.Gallery(
398
+ label="Results",
399
+ elem_classes="output-gallery",
400
+ columns=3,
401
+ rows=2,
402
+ height="auto",
403
+ allow_preview=False, # Disable preview expansion
404
+ show_share_button=False,
405
+ object_fit="cover",
406
+ preview=False
407
+ )
408
+
409
+ # Event handlers
410
+ def select_example(evt: gr.SelectData):
411
+ prompts = {'sketch1.png': 'The camel walks slowly',
412
+ 'sketch2.png': 'The wine in the wine glass sways from side to side',
413
+ 'sketch3.png': 'The squirrel is eating a nut',
414
+ 'sketch4.png': 'The surfer surfs on the waves',
415
+ 'sketch5.png': 'A galloping horse',
416
+ 'sketch6.png': 'The cat walks forward',
417
+ 'sketch7.png': 'The eagle flies in the sky',
418
+ 'sketch8.png': 'The flower is blooming slowly',
419
+ 'sketch9.png': 'The reindeer looks around',
420
+ 'sketch10.png': 'The cloud floats in the sky',
421
+ 'sketch11.png': 'The jazz saxophonist performs on stage with a rhythmic sway, his upper body sways subtly to the rhythm of the music.',
422
+ 'sketch12.png': 'The biker rides on the road',}
423
+ if evt.index < len(example_images):
424
+ example_img = example_images[evt.index]
425
+ prompt_text = prompts.get(os.path.basename(example_img.filename), "")
426
+
427
+
428
+ return [
429
+ example_img,
430
+ prompt_text
431
+ ]
432
+ return [None, ""]
433
+
434
+ example_selection.select(
435
+ select_example,
436
+ None,
437
+ [selected_example, example_prompt]
438
+ )
439
+
440
+ example_generate_btn.click(
441
+ fn=generate_output,
442
+ inputs=[
443
+ selected_example,
444
+ example_prompt,
445
+ example_num_seeds,
446
+ example_lambda
447
+ ],
448
+ outputs=example_gallery
449
+ )
450
+
451
+ upload_generate_btn.click(
452
+ fn=generate_output,
453
+ inputs=[
454
+ upload_image,
455
+ upload_prompt,
456
+ upload_num_seeds,
457
+ upload_lambda
458
+ ],
459
+ outputs=upload_gallery
460
+ )
461
+
462
+ return demo
463
+
464
+ # Launch the app
465
+ if __name__ == "__main__":
466
+ demo = create_gradio_interface()
467
+ demo.launch(
468
+ server_name="0.0.0.0",
469
+ server_port=7860,
470
+ show_api=False
471
+ )
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gunicorn
 
2
  accelerate==0.29.2
3
  blinker==1.9.0
4
  certifi==2024.8.30
 
1
  gunicorn
2
+ spaces
3
  accelerate==0.29.2
4
  blinker==1.9.0
5
  certifi==2024.8.30
static/app_tmp/temp_input.png ADDED