Yw22 commited on
Commit
a9ceb51
·
1 Parent(s): 112b465
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
__asset__/jellyfish.mp4 ADDED
Binary file (536 kB). View file
 
__asset__/lush.mp4 ADDED
Binary file (546 kB). View file
 
__asset__/painting.mp4 ADDED
Binary file (980 kB). View file
 
__asset__/rose.mp4 ADDED
Binary file (334 kB). View file
 
__asset__/turtle.mp4 ADDED
Binary file (583 kB). View file
 
__asset__/tusun.mp4 ADDED
Binary file (557 kB). View file
 
app-12.py DELETED
@@ -1,677 +0,0 @@
1
- import os
2
- import sys
3
-
4
-
5
- print("Installing correct gradio version...")
6
- os.system("pip uninstall -y gradio")
7
- os.system("pip install gradio==4.7.0")
8
- print("Installing Finished!")
9
-
10
-
11
- import gradio as gr
12
- import numpy as np
13
- import cv2
14
- import uuid
15
- import torch
16
- import torchvision
17
- import json
18
- import spaces
19
-
20
- from PIL import Image
21
- from omegaconf import OmegaConf
22
- from einops import rearrange, repeat
23
- from torchvision import transforms,utils
24
- from transformers import CLIPTextModel, CLIPTokenizer
25
- from diffusers import AutoencoderKL, DDIMScheduler
26
-
27
- from pipelines.pipeline_imagecoductor import ImageConductorPipeline
28
- from modules.unet import UNet3DConditionFlowModel
29
- from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil
30
- from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian, save_videos_grid
31
- from utils.lora_utils import add_LoRA_to_controlnet
32
- from utils.visualizer import Visualizer, vis_flow_to_video
33
- #### Description ####
34
- title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
35
-
36
- head = r"""
37
- <div style="text-align: center;">
38
- <h1>Image Conductor: Precision Control for Interactive Video Synthesis</h1>
39
- <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
40
- <a href=""></a>
41
- <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
42
- <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
43
- <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
44
-
45
-
46
- </div>
47
- </br>
48
- </div>
49
- """
50
-
51
-
52
-
53
- descriptions = r"""
54
- Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
55
- 🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
56
- """
57
-
58
-
59
- instructions = r"""
60
- - ⭐️ <b>step1: </b>Upload or select one image from Example.
61
- - ⭐️ <b>step2: </b>Click 'Add Drag' to draw some drags.
62
- - ⭐️ <b>step3: </b>Input text prompt that complements the image (Necessary).
63
- - ⭐️ <b>step4: </b>Select 'Drag Mode' to specify the control of camera transition or object movement.
64
- - ⭐️ <b>step5: </b>Click 'Run' button to generate video assets.
65
- - ⭐️ <b>others: </b>Click 'Delete last drag' to delete the whole lastest path. Click 'Delete last step' to delete the lastest clicked control point.
66
- """
67
-
68
- citation = r"""
69
- If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
70
- [![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
71
- ---
72
-
73
- 📝 **Citation**
74
- <br>
75
- If our work is useful for your research, please consider citing:
76
- ```bibtex
77
- @misc{li2024imageconductor,
78
- title={Image Conductor: Precision Control for Interactive Video Synthesis},
79
- author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
80
- year={2024},
81
- eprint={2406.15339},
82
- archivePrefix={arXiv},
83
- primaryClass={cs.CV}
84
- }
85
- ```
86
-
87
- 📧 **Contact**
88
- <br>
89
- If you have any questions, please feel free to reach me out at <b>ywl@stu.pku.edu.cn</b>.
90
-
91
- # """
92
-
93
- os.makedirs("models/personalized")
94
- os.makedirs("models/sd1-5")
95
-
96
- if not os.path.exists("models/flow_controlnet.ckpt"):
97
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/')
98
- os.system(f'mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt')
99
- print("flow_controlnet Download!", )
100
-
101
- if not os.path.exists("models/image_controlnet.ckpt"):
102
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/')
103
- os.system(f'mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt')
104
- print("image_controlnet Download!", )
105
-
106
- if not os.path.exists("models/unet.ckpt"):
107
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt?download=true -P models/')
108
- os.system(f'mv models/unet.ckpt?download=true models/unet.ckpt')
109
- print("unet Download!", )
110
-
111
-
112
- if not os.path.exists("models/sd1-5/config.json"):
113
- os.system(f'wget -q https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json?download=true -P models/sd1-5/')
114
- os.system(f'mv models/sd1-5/config.json?download=true models/sd1-5/config.json')
115
- print("config Download!", )
116
-
117
-
118
- if not os.path.exists("models/sd1-5/unet.ckpt"):
119
- os.system(f'cp -r models/unet.ckpt models/sd1-5/unet.ckpt')
120
-
121
- # os.system(f'wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')
122
-
123
- if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
124
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized')
125
- os.system(f'mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors')
126
- print("helloobjects_V12c Download!", )
127
-
128
-
129
- if not os.path.exists("models/personalized/TUSUN.safetensors"):
130
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized')
131
- os.system(f'mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors')
132
- print("TUSUN Download!", )
133
-
134
- # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
135
- # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
136
-
137
-
138
- # # 检查命令是否成功
139
- # if mv1 == 0 and mv2 == 0:
140
- # print("file move success!")
141
- # else:
142
- # print("file move failed!")
143
-
144
-
145
- # - - - - - examples - - - - - #
146
-
147
- image_examples = [
148
- ["__asset__/images/object/turtle-1.jpg",
149
- "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
150
- "object",
151
- 11318446767408804497,
152
- "",
153
- "turtle"
154
- ],
155
-
156
- ["__asset__/images/object/rose-1.jpg",
157
- "a red rose engulfed in flames.",
158
- "object",
159
- 6854275249656120509,
160
- "",
161
- "rose",
162
- ],
163
-
164
- ["__asset__/images/object/jellyfish-1.jpg",
165
- "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
166
- "object",
167
- 17966188172968903484,
168
- "HelloObject",
169
- "jellyfish"
170
- ],
171
-
172
-
173
- ["__asset__/images/camera/lush-1.jpg",
174
- "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
175
- "camera",
176
- 7970487946960948963,
177
- "HelloObject",
178
- "lush",
179
- ],
180
-
181
- ["__asset__/images/camera/tusun-1.jpg",
182
- "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
183
- "camera",
184
- 996953226890228361,
185
- "TUSUN",
186
- "tusun",
187
- ],
188
-
189
- ["__asset__/images/camera/painting-1.jpg",
190
- "A oil painting.",
191
- "camera",
192
- 16867854766769816385,
193
- "",
194
- "painting"
195
- ],
196
- ]
197
-
198
-
199
- POINTS = {
200
- 'turtle': "__asset__/trajs/object/turtle-1.json",
201
- 'rose': "__asset__/trajs/object/rose-1.json",
202
- 'jellyfish': "__asset__/trajs/object/jellyfish-1.json",
203
- 'lush': "__asset__/trajs/camera/lush-1.json",
204
- 'tusun': "__asset__/trajs/camera/tusun-1.json",
205
- 'painting': "__asset__/trajs/camera/painting-1.json",
206
- }
207
-
208
- IMAGE_PATH = {
209
- 'turtle': "__asset__/images/object/turtle-1.jpg",
210
- 'rose': "__asset__/images/object/rose-1.jpg",
211
- 'jellyfish': "__asset__/images/object/jellyfish-1.jpg",
212
- 'lush': "__asset__/images/camera/lush-1.jpg",
213
- 'tusun': "__asset__/images/camera/tusun-1.jpg",
214
- 'painting': "__asset__/images/camera/painting-1.jpg",
215
- }
216
-
217
-
218
-
219
- DREAM_BOOTH = {
220
- 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
221
- }
222
-
223
- LORA = {
224
- 'TUSUN': 'models/personalized/TUSUN.safetensors',
225
- }
226
-
227
- LORA_ALPHA = {
228
- 'TUSUN': 0.6,
229
- }
230
-
231
- NPROMPT = {
232
- "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
233
- }
234
-
235
- output_dir = "outputs"
236
- ensure_dirname(output_dir)
237
-
238
- def points_to_flows(track_points, model_length, height, width):
239
- input_drag = np.zeros((model_length - 1, height, width, 2))
240
- for splited_track in track_points:
241
- if len(splited_track) == 1: # stationary point
242
- displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
243
- splited_track = tuple([splited_track[0], displacement_point])
244
- # interpolate the track
245
- splited_track = interpolate_trajectory(splited_track, model_length)
246
- splited_track = splited_track[:model_length]
247
- if len(splited_track) < model_length:
248
- splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))
249
- for i in range(model_length - 1):
250
- start_point = splited_track[i]
251
- end_point = splited_track[i+1]
252
- input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
253
- input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
254
- return input_drag
255
-
256
- class ImageConductor:
257
- def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
258
- self.device = device
259
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
260
- text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(device)
261
- vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
262
- inference_config = OmegaConf.load("configs/inference/inference.yaml")
263
- unet = UNet3DConditionFlowModel.from_pretrained_2d("models/sd1-5/", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
264
-
265
- self.vae = vae
266
-
267
- ### >>> Initialize UNet module >>> ###
268
- load_model(unet, unet_path)
269
-
270
- ### >>> Initialize image controlnet module >>> ###
271
- image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet)
272
- load_model(image_controlnet, image_controlnet_path)
273
- ### >>> Initialize flow controlnet module >>> ###
274
- flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet)
275
- add_LoRA_to_controlnet(lora_rank, flow_controlnet)
276
- load_model(flow_controlnet, flow_controlnet_path)
277
-
278
- unet.eval().to(device)
279
- image_controlnet.eval().to(device)
280
- flow_controlnet.eval().to(device)
281
-
282
- self.pipeline = ImageConductorPipeline(
283
- unet=unet,
284
- vae=vae,
285
- tokenizer=tokenizer,
286
- text_encoder=text_encoder,
287
- scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
288
- image_controlnet=image_controlnet,
289
- flow_controlnet=flow_controlnet,
290
- ).to(device)
291
-
292
-
293
- self.height = height
294
- self.width = width
295
- # _, model_step, _ = split_filename(model_path)
296
- # self.ouput_prefix = f'{model_step}_{width}X{height}'
297
- self.model_length = model_length
298
-
299
- blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True)
300
-
301
- self.blur_kernel = blur_kernel
302
-
303
- @spaces.GPU(duration=180)
304
- def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type):
305
- print("Run!")
306
- if examples_type != "":
307
- ### for adapting high version gradio
308
- tracking_points = gr.State([])
309
- first_frame_path = IMAGE_PATH[examples_type]
310
- points = json.load(open(POINTS[examples_type]))
311
- tracking_points.value.extend(points)
312
- print("example first_frame_path", first_frame_path)
313
- print("example tracking_points", tracking_points.value)
314
-
315
- original_width, original_height=384, 256
316
- if isinstance(tracking_points, list):
317
- input_all_points = tracking_points
318
- else:
319
- input_all_points = tracking_points.value
320
-
321
- print("input_all_points", input_all_points)
322
- resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
323
-
324
- dir, base, ext = split_filename(first_frame_path)
325
- id = base.split('_')[-1]
326
-
327
-
328
- visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
329
-
330
- ## image condition
331
- image_transforms = transforms.Compose([
332
- transforms.RandomResizedCrop(
333
- (self.height, self.width), (1.0, 1.0),
334
- ratio=(self.width/self.height, self.width/self.height)
335
- ),
336
- transforms.ToTensor(),
337
- ])
338
-
339
- image_paths = [first_frame_path]
340
- controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
341
- controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
342
- controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
343
- num_controlnet_images = controlnet_images.shape[2]
344
- controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
345
- self.vae.to(device)
346
- controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
347
- controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
348
-
349
- # flow condition
350
- controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
351
- for i in range(0, self.model_length-1):
352
- controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
353
- controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow
354
- os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
355
- trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
356
- torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
357
- controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, :self.model_length, ...]
358
- controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w").float().to(device)
359
-
360
- dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
361
- lora_model_path = LORA.get(personalized, '')
362
- lora_alpha = LORA_ALPHA.get(personalized, 0.6)
363
- self.pipeline = load_weights(
364
- self.pipeline,
365
- dreambooth_model_path = dreambooth_model_path,
366
- lora_model_path = lora_model_path,
367
- lora_alpha = lora_alpha,
368
- ).to(device)
369
-
370
- if NPROMPT.get(personalized, '') != '':
371
- negative_prompt = NPROMPT.get(personalized)
372
-
373
- if randomize_seed:
374
- random_seed = torch.seed()
375
- else:
376
- seed = int(seed)
377
- random_seed = seed
378
- torch.manual_seed(random_seed)
379
- torch.cuda.manual_seed_all(random_seed)
380
- print(f"current seed: {torch.initial_seed()}")
381
- sample = self.pipeline(
382
- prompt,
383
- negative_prompt = negative_prompt,
384
- num_inference_steps = num_inference_steps,
385
- guidance_scale = guidance_scale,
386
- width = self.width,
387
- height = self.height,
388
- video_length = self.model_length,
389
- controlnet_images = controlnet_images, # 1 4 1 32 48
390
- controlnet_image_index = [0],
391
- controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384]
392
- control_mode = drag_mode,
393
- eval_mode = True,
394
- ).videos
395
-
396
- # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
397
- # vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
398
- # torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
399
-
400
- outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
401
- save_videos_grid(sample[0][None], outputs_path)
402
- print("Done!")
403
- return {output_image: visualized_drag, output_video: outputs_path}
404
-
405
-
406
- def reset_states(first_frame_path, tracking_points):
407
- first_frame_path = gr.State()
408
- tracking_points = gr.State([])
409
- return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
410
-
411
-
412
- def preprocess_image(image, tracking_points):
413
- image_pil = image2pil(image.name)
414
- raw_w, raw_h = image_pil.size
415
- resize_ratio = max(384/raw_w, 256/raw_h)
416
- image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
417
- image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
418
- id = str(uuid.uuid4())[:4]
419
- first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
420
- image_pil.save(first_frame_path, quality=95)
421
- tracking_points = gr.State([])
422
- return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points, personalized:""}
423
-
424
-
425
- def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
426
- if drag_mode=='object':
427
- color = (255, 0, 0, 255)
428
- elif drag_mode=='camera':
429
- color = (0, 0, 255, 255)
430
-
431
- if not isinstance(tracking_points ,list):
432
- print(f"You selected {evt.value} at {evt.index} from {evt.target}")
433
- tracking_points.value[-1].append(evt.index)
434
- print(tracking_points.value)
435
- tracking_points_values = tracking_points.value
436
- else:
437
- try:
438
- tracking_points[-1].append(evt.index)
439
- except Exception as e:
440
- tracking_points.append([])
441
- tracking_points[-1].append(evt.index)
442
- print(f"Solved Error: {e}")
443
-
444
- tracking_points_values = tracking_points
445
-
446
-
447
- transparent_background = Image.open(first_frame_path).convert('RGBA')
448
- w, h = transparent_background.size
449
- transparent_layer = np.zeros((h, w, 4))
450
-
451
- for track in tracking_points_values:
452
- if len(track) > 1:
453
- for i in range(len(track)-1):
454
- start_point = track[i]
455
- end_point = track[i+1]
456
- vx = end_point[0] - start_point[0]
457
- vy = end_point[1] - start_point[1]
458
- arrow_length = np.sqrt(vx**2 + vy**2)
459
- if i == len(track)-2:
460
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
461
- else:
462
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
463
- else:
464
- cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
465
-
466
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
467
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
468
- return {tracking_points_var: tracking_points, input_image: trajectory_map}
469
-
470
-
471
- def add_drag(tracking_points):
472
- if not isinstance(tracking_points ,list):
473
- # print("before", tracking_points.value)
474
- tracking_points.value.append([])
475
- # print(tracking_points.value)
476
- else:
477
- tracking_points.append([])
478
- return {tracking_points_var: tracking_points}
479
-
480
-
481
- def delete_last_drag(tracking_points, first_frame_path, drag_mode):
482
- if drag_mode=='object':
483
- color = (255, 0, 0, 255)
484
- elif drag_mode=='camera':
485
- color = (0, 0, 255, 255)
486
- tracking_points.value.pop()
487
- transparent_background = Image.open(first_frame_path).convert('RGBA')
488
- w, h = transparent_background.size
489
- transparent_layer = np.zeros((h, w, 4))
490
- for track in tracking_points.value:
491
- if len(track) > 1:
492
- for i in range(len(track)-1):
493
- start_point = track[i]
494
- end_point = track[i+1]
495
- vx = end_point[0] - start_point[0]
496
- vy = end_point[1] - start_point[1]
497
- arrow_length = np.sqrt(vx**2 + vy**2)
498
- if i == len(track)-2:
499
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
500
- else:
501
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
502
- else:
503
- cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
504
-
505
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
506
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
507
- return {tracking_points_var: tracking_points, input_image: trajectory_map}
508
-
509
-
510
- def delete_last_step(tracking_points, first_frame_path, drag_mode):
511
- if drag_mode=='object':
512
- color = (255, 0, 0, 255)
513
- elif drag_mode=='camera':
514
- color = (0, 0, 255, 255)
515
- tracking_points.value[-1].pop()
516
- transparent_background = Image.open(first_frame_path).convert('RGBA')
517
- w, h = transparent_background.size
518
- transparent_layer = np.zeros((h, w, 4))
519
- for track in tracking_points.value:
520
- if len(track) > 1:
521
- for i in range(len(track)-1):
522
- start_point = track[i]
523
- end_point = track[i+1]
524
- vx = end_point[0] - start_point[0]
525
- vy = end_point[1] - start_point[1]
526
- arrow_length = np.sqrt(vx**2 + vy**2)
527
- if i == len(track)-2:
528
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
529
- else:
530
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
531
- else:
532
- cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1)
533
-
534
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
535
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
536
- return {tracking_points_var: tracking_points, input_image: trajectory_map}
537
-
538
-
539
- block = gr.Blocks(
540
- theme=gr.themes.Soft(
541
- radius_size=gr.themes.sizes.radius_none,
542
- text_size=gr.themes.sizes.text_md
543
- )
544
- )
545
- with block:
546
- with gr.Row():
547
- with gr.Column():
548
- gr.HTML(head)
549
-
550
- gr.Markdown(descriptions)
551
-
552
- with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
553
- with gr.Row(equal_height=True):
554
- gr.Markdown(instructions)
555
-
556
-
557
- # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
558
- device = torch.device("cuda")
559
- unet_path = 'models/unet.ckpt'
560
- image_controlnet_path = 'models/image_controlnet.ckpt'
561
- flow_controlnet_path = 'models/flow_controlnet.ckpt'
562
- ImageConductor_net = ImageConductor(device=device,
563
- unet_path=unet_path,
564
- image_controlnet_path=image_controlnet_path,
565
- flow_controlnet_path=flow_controlnet_path,
566
- height=256,
567
- width=384,
568
- model_length=16
569
- )
570
- first_frame_path_var = gr.State(value=None)
571
- tracking_points_var = gr.State([])
572
-
573
- with gr.Row():
574
- with gr.Column(scale=1):
575
- image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
576
- add_drag_button = gr.Button(value="Add Drag")
577
- reset_button = gr.Button(value="Reset")
578
- delete_last_drag_button = gr.Button(value="Delete last drag")
579
- delete_last_step_button = gr.Button(value="Delete last step")
580
-
581
-
582
-
583
- with gr.Column(scale=7):
584
- with gr.Row():
585
- with gr.Column(scale=6):
586
- input_image = gr.Image(label="Input Image",
587
- interactive=True,
588
- height=300,
589
- width=384,)
590
- with gr.Column(scale=6):
591
- output_image = gr.Image(label="Motion Path",
592
- interactive=False,
593
- height=256,
594
- width=384,)
595
- with gr.Row():
596
- with gr.Column(scale=1):
597
- prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
598
- negative_prompt = gr.Text(
599
- label="Negative Prompt",
600
- max_lines=5,
601
- placeholder="Please input your negative prompt",
602
- value='worst quality, low quality, letterboxed',lines=1
603
- )
604
- drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
605
- run_button = gr.Button(value="Run")
606
-
607
- with gr.Accordion("More input params", open=False, elem_id="accordion1"):
608
- with gr.Group():
609
- seed = gr.Textbox(
610
- label="Seed: ", value=561793204,
611
- )
612
- randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
613
-
614
- with gr.Group():
615
- with gr.Row():
616
- guidance_scale = gr.Slider(
617
- label="Guidance scale",
618
- minimum=1,
619
- maximum=12,
620
- step=0.1,
621
- value=8.5,
622
- )
623
- num_inference_steps = gr.Slider(
624
- label="Number of inference steps",
625
- minimum=1,
626
- maximum=50,
627
- step=1,
628
- value=25,
629
- )
630
-
631
- with gr.Group():
632
- personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
633
- examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
634
-
635
- with gr.Column(scale=7):
636
- # output_video = gr.Video(
637
- # label="Output Video",
638
- # width=384,
639
- # height=256)
640
- output_video = gr.Image(label="Output Video",
641
- height=256,
642
- width=384,)
643
-
644
-
645
- with gr.Row():
646
-
647
-
648
- example = gr.Examples(
649
- label="Input Example",
650
- examples=image_examples,
651
- inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
652
- examples_per_page=10,
653
- cache_examples=False,
654
- )
655
-
656
-
657
- with gr.Row():
658
- gr.Markdown(citation)
659
-
660
-
661
- image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var, personalized])
662
-
663
- add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
664
-
665
- delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
666
-
667
- delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
668
-
669
- reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
670
-
671
- input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
672
-
673
- run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
674
- negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
675
- [output_image, output_video])
676
-
677
- block.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -4,8 +4,8 @@ import sys
4
 
5
  print("Installing correct gradio version...")
6
  os.system("pip uninstall -y gradio")
7
- os.system("pip install gradio==4.7.0")
8
- # print("Installing Finished")
9
 
10
 
11
  import gradio as gr
@@ -150,70 +150,75 @@ image_examples = [
150
  "object",
151
  11318446767408804497,
152
  "",
153
- "__asset__/images/object/turtle-1.jpg",
154
- json.load(open("__asset__/trajs/object/turtle-1.json"))
155
  ],
156
 
157
- # ["__asset__/images/object/rose-1.jpg",
158
- # "a red rose engulfed in flames.",
159
- # "object",
160
- # 6854275249656120509,
161
- # "",
162
- # "rose",
163
- # ],
 
164
 
165
- # ["__asset__/images/object/jellyfish-1.jpg",
166
- # "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
167
- # "object",
168
- # 17966188172968903484,
169
- # "HelloObject",
170
- # "jellyfish"
171
- # ],
 
172
 
173
 
174
- # ["__asset__/images/camera/lush-1.jpg",
175
- # "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
176
- # "camera",
177
- # 7970487946960948963,
178
- # "HelloObject",
179
- # "lush",
180
- # ],
 
181
 
182
- # ["__asset__/images/camera/tusun-1.jpg",
183
- # "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
184
- # "camera",
185
- # 996953226890228361,
186
- # "TUSUN",
187
- # "tusun",
188
- # ],
 
189
 
190
- # ["__asset__/images/camera/painting-1.jpg",
191
- # "A oil painting.",
192
- # "camera",
193
- # 16867854766769816385,
194
- # "",
195
- # "painting"
196
- # ],
 
197
  ]
198
 
199
 
200
- # POINTS = {
201
- # 'turtle': "__asset__/trajs/object/turtle-1.json",
202
- # 'rose': "__asset__/trajs/object/rose-1.json",
203
- # 'jellyfish': "__asset__/trajs/object/jellyfish-1.json",
204
- # 'lush': "__asset__/trajs/camera/lush-1.json",
205
- # 'tusun': "__asset__/trajs/camera/tusun-1.json",
206
- # 'painting': "__asset__/trajs/camera/painting-1.json",
207
- # }
208
 
209
- # IMAGE_PATH = {
210
- # 'turtle': "__asset__/images/object/turtle-1.jpg",
211
- # 'rose': "__asset__/images/object/rose-1.jpg",
212
- # 'jellyfish': "__asset__/images/object/jellyfish-1.jpg",
213
- # 'lush': "__asset__/images/camera/lush-1.jpg",
214
- # 'tusun': "__asset__/images/camera/tusun-1.jpg",
215
- # 'painting': "__asset__/images/camera/painting-1.jpg",
216
- # }
217
 
218
 
219
 
@@ -301,27 +306,31 @@ class ImageConductor:
301
 
302
  self.blur_kernel = blur_kernel
303
 
304
- @spaces.GPU(duration=120)
305
- def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized,):
306
-
307
-
 
 
 
 
 
 
 
 
308
  original_width, original_height=384, 256
309
  if isinstance(tracking_points, list):
310
  input_all_points = tracking_points
311
  else:
312
  input_all_points = tracking_points.value
313
 
314
-
315
  resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
316
 
317
  dir, base, ext = split_filename(first_frame_path)
318
  id = base.split('_')[-1]
319
 
320
 
321
- # with open(f'{output_dir}/points-{id}.json', 'w') as f:
322
- # json.dump(input_all_points, f)
323
-
324
-
325
  visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
326
 
327
  ## image condition
@@ -333,9 +342,8 @@ class ImageConductor:
333
  transforms.ToTensor(),
334
  ])
335
 
336
- image_norm = lambda x: x
337
  image_paths = [first_frame_path]
338
- controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
339
  controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
340
  controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
341
  num_controlnet_images = controlnet_images.shape[2]
@@ -391,14 +399,13 @@ class ImageConductor:
391
  eval_mode = True,
392
  ).videos
393
 
394
- # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
395
- # vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
396
- # torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
397
-
398
 
399
- outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
400
- save_videos_grid(sample[0][None], outputs_path)
401
-
402
  return {output_image: visualized_drag, output_video: outputs_path}
403
 
404
 
@@ -408,7 +415,7 @@ def reset_states(first_frame_path, tracking_points):
408
  return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
409
 
410
 
411
- def preprocess_image(image):
412
  image_pil = image2pil(image.name)
413
  raw_w, raw_h = image_pil.size
414
  resize_ratio = max(384/raw_w, 256/raw_h)
@@ -417,7 +424,8 @@ def preprocess_image(image):
417
  id = str(uuid.uuid4())[:4]
418
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
419
  image_pil.save(first_frame_path, quality=95)
420
- return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: gr.State([]), personalized: ""}
 
421
 
422
 
423
  def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
@@ -426,14 +434,27 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
426
  elif drag_mode=='camera':
427
  color = (0, 0, 255, 255)
428
 
429
- print(f"You selected {evt.value} at {evt.index} from {evt.target}")
430
- tracking_points.value[-1].append(evt.index)
431
- print(tracking_points.value)
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  transparent_background = Image.open(first_frame_path).convert('RGBA')
434
  w, h = transparent_background.size
435
  transparent_layer = np.zeros((h, w, 4))
436
- for track in tracking_points.value:
 
437
  if len(track) > 1:
438
  for i in range(len(track)-1):
439
  start_point = track[i]
@@ -454,9 +475,12 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
454
 
455
 
456
  def add_drag(tracking_points):
457
- # import ipdb; ipdb.set_trace()
458
- tracking_points.value.append([])
459
- print(tracking_points.value)
 
 
 
460
  return {tracking_points_var: tracking_points}
461
 
462
 
@@ -518,144 +542,142 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
518
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
519
 
520
 
521
- if __name__=="__main__":
522
- block = gr.Blocks(
523
- theme=gr.themes.Soft(
524
- radius_size=gr.themes.sizes.radius_none,
525
- text_size=gr.themes.sizes.text_md
526
- )
527
- ).queue()
528
- with block as demo:
529
- with gr.Row():
530
- with gr.Column():
531
- gr.HTML(head)
532
-
533
- gr.Markdown(descriptions)
534
-
535
- with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
536
- with gr.Row(equal_height=True):
537
- gr.Markdown(instructions)
538
-
539
-
540
- # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
541
- device = torch.device("cuda")
542
- unet_path = 'models/unet.ckpt'
543
- image_controlnet_path = 'models/image_controlnet.ckpt'
544
- flow_controlnet_path = 'models/flow_controlnet.ckpt'
545
- ImageConductor_net = ImageConductor(device=device,
546
- unet_path=unet_path,
547
- image_controlnet_path=image_controlnet_path,
548
- flow_controlnet_path=flow_controlnet_path,
549
- height=256,
550
- width=384,
551
- model_length=16
552
- )
553
- first_frame_path_var = gr.State(value=None)
554
- tracking_points_var = gr.State([])
555
-
556
- with gr.Row():
557
- with gr.Column(scale=1):
558
- image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
559
- add_drag_button = gr.Button(value="Add Drag")
560
- reset_button = gr.Button(value="Reset")
561
- delete_last_drag_button = gr.Button(value="Delete last drag")
562
- delete_last_step_button = gr.Button(value="Delete last step")
563
-
564
-
565
-
566
- with gr.Column(scale=7):
567
- with gr.Row():
568
- with gr.Column(scale=6):
569
- input_image = gr.Image(label="Input Image",
570
- interactive=True,
571
- height=300,
572
- width=384,)
573
- with gr.Column(scale=6):
574
- output_image = gr.Image(label="Motion Path",
575
- interactive=False,
576
- height=256,
577
- width=384,)
578
- with gr.Row():
579
- with gr.Column(scale=1):
580
- prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
581
- negative_prompt = gr.Text(
582
- label="Negative Prompt",
583
- max_lines=5,
584
- placeholder="Please input your negative prompt",
585
- value='worst quality, low quality, letterboxed',lines=1
586
- )
587
- drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
588
- run_button = gr.Button(value="Run")
589
 
590
- with gr.Accordion("More input params", open=False, elem_id="accordion1"):
591
- with gr.Group():
592
- seed = gr.Textbox(
593
- label="Seed: ", value=561793204,
594
- )
595
- randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
596
-
597
- with gr.Group():
598
- with gr.Row():
599
- guidance_scale = gr.Slider(
600
- label="Guidance scale",
601
- minimum=1,
602
- maximum=12,
603
- step=0.1,
604
- value=8.5,
605
- )
606
- num_inference_steps = gr.Slider(
607
- label="Number of inference steps",
608
- minimum=1,
609
- maximum=50,
610
- step=1,
611
- value=25,
612
- )
613
-
614
- with gr.Group():
615
- personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
616
- # examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
617
-
618
- with gr.Column(scale=7):
619
- # output_video = gr.Video(
620
- # label="Output Video",
621
- # width=384,
622
- # height=256)
623
- output_video = gr.Image(label="Output Video",
624
  height=256,
625
  width=384,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
626
 
627
-
628
- with gr.Row():
629
- def process_examples(input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var):
630
- return input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var
631
- example = gr.Examples(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632
  label="Input Example",
633
  examples=image_examples,
634
- inputs=[input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var],
635
- outputs=[input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var],
636
- fn=process_examples,
637
  examples_per_page=10,
638
  cache_examples=False,
639
  )
640
-
641
- with gr.Row():
642
- gr.Markdown(citation)
643
-
644
 
645
- image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path_var, tracking_points_var, personalized])
 
 
 
 
 
646
 
647
- add_drag_button.click(add_drag, [tracking_points_var], tracking_points_var)
648
 
649
- delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
650
 
651
- delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
652
 
653
- reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
654
 
655
- input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
656
 
657
- run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
658
- negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized],
659
- [output_image, output_video])
660
 
661
  block.queue().launch()
 
4
 
5
  print("Installing correct gradio version...")
6
  os.system("pip uninstall -y gradio")
7
+ os.system("pip install gradio==4.37.2")
8
+ print("Installing Finished!")
9
 
10
 
11
  import gradio as gr
 
150
  "object",
151
  11318446767408804497,
152
  "",
153
+ "turtle",
154
+ "__asset__/turtle.mp4"
155
  ],
156
 
157
+ ["__asset__/images/object/rose-1.jpg",
158
+ "a red rose engulfed in flames.",
159
+ "object",
160
+ 6854275249656120509,
161
+ "",
162
+ "rose",
163
+ "__asset__/rose.mp4"
164
+ ],
165
 
166
+ ["__asset__/images/object/jellyfish-1.jpg",
167
+ "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
168
+ "object",
169
+ 17966188172968903484,
170
+ "HelloObject",
171
+ "jellyfish",
172
+ "__asset__/jellyfish.mp4"
173
+ ],
174
 
175
 
176
+ ["__asset__/images/camera/lush-1.jpg",
177
+ "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
178
+ "camera",
179
+ 7970487946960948963,
180
+ "HelloObject",
181
+ "lush",
182
+ "__asset__/lush.mp4",
183
+ ],
184
 
185
+ ["__asset__/images/camera/tusun-1.jpg",
186
+ "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
187
+ "camera",
188
+ 996953226890228361,
189
+ "TUSUN",
190
+ "tusun",
191
+ "__asset__/tusun.mp4"
192
+ ],
193
 
194
+ ["__asset__/images/camera/painting-1.jpg",
195
+ "A oil painting.",
196
+ "camera",
197
+ 16867854766769816385,
198
+ "",
199
+ "painting",
200
+ "__asset__/painting.mp4"
201
+ ],
202
  ]
203
 
204
 
205
+ POINTS = {
206
+ 'turtle': "__asset__/trajs/object/turtle-1.json",
207
+ 'rose': "__asset__/trajs/object/rose-1.json",
208
+ 'jellyfish': "__asset__/trajs/object/jellyfish-1.json",
209
+ 'lush': "__asset__/trajs/camera/lush-1.json",
210
+ 'tusun': "__asset__/trajs/camera/tusun-1.json",
211
+ 'painting': "__asset__/trajs/camera/painting-1.json",
212
+ }
213
 
214
+ IMAGE_PATH = {
215
+ 'turtle': "__asset__/images/object/turtle-1.jpg",
216
+ 'rose': "__asset__/images/object/rose-1.jpg",
217
+ 'jellyfish': "__asset__/images/object/jellyfish-1.jpg",
218
+ 'lush': "__asset__/images/camera/lush-1.jpg",
219
+ 'tusun': "__asset__/images/camera/tusun-1.jpg",
220
+ 'painting': "__asset__/images/camera/painting-1.jpg",
221
+ }
222
 
223
 
224
 
 
306
 
307
  self.blur_kernel = blur_kernel
308
 
309
+ @spaces.GPU(duration=180)
310
+ def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type):
311
+ print("Run!")
312
+ if examples_type != "":
313
+ ### for adapting high version gradio
314
+ tracking_points = gr.State([])
315
+ first_frame_path = IMAGE_PATH[examples_type]
316
+ points = json.load(open(POINTS[examples_type]))
317
+ tracking_points.value.extend(points)
318
+ print("example first_frame_path", first_frame_path)
319
+ print("example tracking_points", tracking_points.value)
320
+
321
  original_width, original_height=384, 256
322
  if isinstance(tracking_points, list):
323
  input_all_points = tracking_points
324
  else:
325
  input_all_points = tracking_points.value
326
 
327
+ print("input_all_points", input_all_points)
328
  resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
329
 
330
  dir, base, ext = split_filename(first_frame_path)
331
  id = base.split('_')[-1]
332
 
333
 
 
 
 
 
334
  visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
335
 
336
  ## image condition
 
342
  transforms.ToTensor(),
343
  ])
344
 
 
345
  image_paths = [first_frame_path]
346
+ controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
347
  controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
348
  controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
349
  num_controlnet_images = controlnet_images.shape[2]
 
399
  eval_mode = True,
400
  ).videos
401
 
402
+ outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
403
+ vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
404
+ torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
 
405
 
406
+ # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
407
+ # save_videos_grid(sample[0][None], outputs_path)
408
+ print("Done!")
409
  return {output_image: visualized_drag, output_video: outputs_path}
410
 
411
 
 
415
  return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
416
 
417
 
418
+ def preprocess_image(image, tracking_points):
419
  image_pil = image2pil(image.name)
420
  raw_w, raw_h = image_pil.size
421
  resize_ratio = max(384/raw_w, 256/raw_h)
 
424
  id = str(uuid.uuid4())[:4]
425
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
426
  image_pil.save(first_frame_path, quality=95)
427
+ tracking_points = gr.State([])
428
+ return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points, personalized:""}
429
 
430
 
431
  def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
 
434
  elif drag_mode=='camera':
435
  color = (0, 0, 255, 255)
436
 
437
+ if not isinstance(tracking_points ,list):
438
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
439
+ tracking_points.value[-1].append(evt.index)
440
+ print(tracking_points.value)
441
+ tracking_points_values = tracking_points.value
442
+ else:
443
+ try:
444
+ tracking_points[-1].append(evt.index)
445
+ except Exception as e:
446
+ tracking_points.append([])
447
+ tracking_points[-1].append(evt.index)
448
+ print(f"Solved Error: {e}")
449
+
450
+ tracking_points_values = tracking_points
451
+
452
 
453
  transparent_background = Image.open(first_frame_path).convert('RGBA')
454
  w, h = transparent_background.size
455
  transparent_layer = np.zeros((h, w, 4))
456
+
457
+ for track in tracking_points_values:
458
  if len(track) > 1:
459
  for i in range(len(track)-1):
460
  start_point = track[i]
 
475
 
476
 
477
  def add_drag(tracking_points):
478
+ if not isinstance(tracking_points ,list):
479
+ # print("before", tracking_points.value)
480
+ tracking_points.value.append([])
481
+ # print(tracking_points.value)
482
+ else:
483
+ tracking_points.append([])
484
  return {tracking_points_var: tracking_points}
485
 
486
 
 
542
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
543
 
544
 
545
+ block = gr.Blocks(
546
+ theme=gr.themes.Soft(
547
+ radius_size=gr.themes.sizes.radius_none,
548
+ text_size=gr.themes.sizes.text_md
549
+ )
550
+ )
551
+ with block:
552
+ with gr.Row():
553
+ with gr.Column():
554
+ gr.HTML(head)
555
+
556
+ gr.Markdown(descriptions)
557
+
558
+ with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
559
+ with gr.Row(equal_height=True):
560
+ gr.Markdown(instructions)
561
+
562
+
563
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
564
+ device = torch.device("cuda")
565
+ unet_path = 'models/unet.ckpt'
566
+ image_controlnet_path = 'models/image_controlnet.ckpt'
567
+ flow_controlnet_path = 'models/flow_controlnet.ckpt'
568
+ ImageConductor_net = ImageConductor(device=device,
569
+ unet_path=unet_path,
570
+ image_controlnet_path=image_controlnet_path,
571
+ flow_controlnet_path=flow_controlnet_path,
572
+ height=256,
573
+ width=384,
574
+ model_length=16
575
+ )
576
+ first_frame_path_var = gr.State(value=None)
577
+ tracking_points_var = gr.State([])
578
+
579
+ with gr.Row():
580
+ with gr.Column(scale=1):
581
+ image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
582
+ add_drag_button = gr.Button(value="Add Drag")
583
+ reset_button = gr.Button(value="Reset")
584
+ delete_last_drag_button = gr.Button(value="Delete last drag")
585
+ delete_last_step_button = gr.Button(value="Delete last step")
586
+
587
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
 
589
+ with gr.Column(scale=7):
590
+ with gr.Row():
591
+ with gr.Column(scale=6):
592
+ input_image = gr.Image(label="Input Image",
593
+ interactive=True,
594
+ height=300,
595
+ width=384,)
596
+ with gr.Column(scale=6):
597
+ output_image = gr.Image(label="Motion Path",
598
+ interactive=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  height=256,
600
  width=384,)
601
+ with gr.Row():
602
+ with gr.Column(scale=1):
603
+ prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
604
+ negative_prompt = gr.Text(
605
+ label="Negative Prompt",
606
+ max_lines=5,
607
+ placeholder="Please input your negative prompt",
608
+ value='worst quality, low quality, letterboxed',lines=1
609
+ )
610
+ drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
611
+ run_button = gr.Button(value="Run")
612
+
613
+ with gr.Accordion("More input params", open=False, elem_id="accordion1"):
614
+ with gr.Group():
615
+ seed = gr.Textbox(
616
+ label="Seed: ", value=561793204,
617
+ )
618
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
619
 
620
+ with gr.Group():
621
+ with gr.Row():
622
+ guidance_scale = gr.Slider(
623
+ label="Guidance scale",
624
+ minimum=1,
625
+ maximum=12,
626
+ step=0.1,
627
+ value=8.5,
628
+ )
629
+ num_inference_steps = gr.Slider(
630
+ label="Number of inference steps",
631
+ minimum=1,
632
+ maximum=50,
633
+ step=1,
634
+ value=25,
635
+ )
636
+
637
+ with gr.Group():
638
+ personalized = gr.Dropdown(label="Personalized", choices=["", 'HelloObject', 'TUSUN'], value="")
639
+ examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
640
+
641
+ with gr.Column(scale=7):
642
+ output_video = gr.Video(
643
+ label="Output Video",
644
+ width=384,
645
+ height=256)
646
+ # output_video = gr.Image(label="Output Video",
647
+ # height=256,
648
+ # width=384,)
649
+
650
+
651
+ with gr.Row():
652
+
653
+
654
+ example = gr.Examples(
655
  label="Input Example",
656
  examples=image_examples,
657
+ inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type, output_video],
 
 
658
  examples_per_page=10,
659
  cache_examples=False,
660
  )
 
 
 
 
661
 
662
+
663
+ with gr.Row():
664
+ gr.Markdown(citation)
665
+
666
+
667
+ image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var, personalized])
668
 
669
+ add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
670
 
671
+ delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
672
 
673
+ delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
674
 
675
+ reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
676
 
677
+ input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
678
 
679
+ run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
680
+ negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
681
+ [output_image, output_video])
682
 
683
  block.queue().launch()