update
- .gitignore +2 -1
- app.py +469 -281
.gitignore CHANGED
@@ -197,4 +197,5 @@ slurm-*.out
 .vscode
 
 data/
-tmp/
+tmp/
+.gradio/
app.py CHANGED
@@ -33,6 +33,9 @@ from submodules.MoGe.moge.model import MoGeModel
 from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from submodules.vggt.vggt.models.vggt import VGGT
 
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 # Parse command line arguments
 parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
 parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
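Note: the two added lines make torch.compile/TorchDynamo fall back to eager execution instead of raising when graph capture fails. A minimal sketch of the effect, assuming PyTorch >= 2.0 (the function `f` is illustrative):

    import torch
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True  # log and fall back to eager on compile errors

    @torch.compile
    def f(x):
        return x * 2

    print(f(torch.ones(3)))  # runs eagerly if dynamo hits an unsupported op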
@@ -82,9 +85,9 @@ def load_media(media_path, max_frames=49, transform=None):
 
     # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
     if duration > 6.0:
-
-        frames = load_video(media_path,
-        fps =
+        # Use the max_frames parameter instead of sampling_fps
+        frames = load_video(media_path, max_frames=max_frames)
+        fps = max_frames / 6.0  # compute the equivalent fps
     # Cases 2 and 3: Video shorter than 6 seconds
     else:
         # Load all frames
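Note: Case 1 now always returns `max_frames` frames covering the first 6 seconds, so the fps is derived rather than read from the file. A quick check with the default `max_frames=49` (the value in the function signature):

    max_frames = 49
    fps = max_frames / 6.0
    print(fps)  # 8.166...: about 8.17 frames per second of equivalent playback rate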
@@ -195,10 +198,10 @@ def get_vggt_model():
 def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
     """Process video motion transfer task"""
     try:
-        #
+        # Save the uploaded file
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         print(f"DEBUG: Repaint option: {mt_repaint_option}")
         print(f"DEBUG: Repaint image: {mt_repaint_image}")
@@ -253,28 +256,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
         tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # Return the intermediate results without applying tracking
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_camera_control(source, prompt, camera_motion, tracking_method):
     """Process camera control task"""
     try:
-        #
+        # Save the uploaded file
         input_media_path = save_uploaded_file(source)
         if input_media_path is None:
-            return None, None
+            return None, None, None, None, None
 
         print(f"DEBUG: Camera motion: '{camera_motion}'")
         print(f"DEBUG: Tracking method: '{tracking_method}'")
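Note: the task handlers no longer call `das.apply_tracking` themselves; they hand the intermediate tensors back so a separate GPU stage (`apply_tracking_unified`, added below) can consume them. Because each run button is wired to five output components (`tracking_video` plus four state variables), every return path must yield five values; the early-return guards are normalized accordingly. A hypothetical minimal repro of that Gradio rule (names are illustrative):

    def handler(ok: bool):
        # an event handler wired to five output components must return
        # exactly five values on every path, including failure paths
        if not ok:
            return None, None, None, None, None
        return "tracking.mp4", "video_tensor", "tracking_tensor", "repaint_img", 24.0

    print(len(handler(False)))  # 5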
@@ -317,24 +312,8 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         # Use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # Use the wrapped VGGT helper
+        extr, intr = process_vggt(video_tensor)
 
         cam_motion.set_intr(intr)
         cam_motion.set_extr(extr)
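Note: the inlined VGGT preprocessing moves into `process_vggt`, which resizes to a fixed width of 518 and snaps the height to a multiple of 14 (presumably the ViT patch size), center-cropping if the height still exceeds 518. A worked example of the resize rule with an assumed 640x480 input:

    h, w = 480, 640
    new_width = 518
    new_height = round(h * (new_width / w) / 14) * 14
    print(new_height)  # 392: 480 * (518/640) = 388.5, / 14 = 27.75, round -> 28, * 14 -> 392
    assert new_height % 14 == 0 and new_height <= 518  # no center crop needed here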
@@ -345,23 +324,15 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
         print("Camera motion applied")
 
-        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks,
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # Return the intermediate results without applying tracking
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
     """Process object manipulation task"""
@@ -369,12 +340,12 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         # Save uploaded files
         input_image_path = save_uploaded_file(source)
         if input_image_path is None:
-            return None, None
+            return None, None, None, None, None
 
         object_mask_path = save_uploaded_file(object_mask)
         if object_mask_path is None:
             print("Object mask not provided")
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_image_path)
@@ -424,24 +395,8 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         # Use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # Use the wrapped VGGT helper
+        extr, intr = process_vggt(video_tensor)
 
         pred_tracks = motion_generator.apply_motion(
             pred_tracks=pred_tracks.squeeze(),
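Note: camera control and object manipulation now delegate to the same `process_vggt` helper, which wraps inference in `torch.no_grad()` plus autocast. A minimal sketch of that inference context, assuming a CUDA device (model and shapes are illustrative; newer PyTorch spells the context `torch.amp.autocast("cuda", ...)`):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(1, 8, device="cuda")

    with torch.no_grad():  # no autograd graph during inference
        with torch.cuda.amp.autocast(dtype=torch.float16):  # mixed-precision matmuls
            y = model(x)
    print(y.dtype)  # torch.float16 inside the autocast region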
@@ -453,23 +408,15 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         )
         print(f"Object motion '{object_motion}' applied using provided mask")
 
-        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0),
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # Return the intermediate results without applying tracking
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
     """Process mesh animation task"""
@@ -477,11 +424,11 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         tracking_video_path = save_uploaded_file(tracking_video)
         if tracking_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
@@ -494,7 +441,6 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
             repaint_img_tensor, _, _ = load_media(repaint_path)
             repaint_img_tensor = repaint_img_tensor[0]  # take the first frame
         elif ma_repaint_option == "Yes":
-
             repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
             repaint_img_tensor = repainter.repaint(
                 video_tensor[0],
@@ -502,20 +448,12 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
                 depth_path=None
             )
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_video_path, output_path
+        # Return the uploaded tracking video path directly instead of generating a new one
+        return tracking_video_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def generate_tracking_cotracker(video_tensor, density=30):
     """Generate tracking on the CPU, using only the first frame's depth information and matrix operations for efficiency
@@ -569,22 +507,192 @@ def generate_tracking_cotracker(video_tensor, density=30):
     # Return the results
     return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
 
+@spaces.GPU(duration=240)
+def apply_tracking_unified(video_tensor, tracking_tensor, repaint_img_tensor, prompt, fps):
+    """Unified apply-tracking function"""
+    try:
+        if video_tensor is None or tracking_tensor is None:
+            return None
+
+        das = get_das_pipeline()
+        output_path = das.apply_tracking(
+            video_tensor=video_tensor,
+            fps=fps,
+            tracking_tensor=tracking_tensor,
+            img_cond_tensor=repaint_img_tensor,
+            prompt=prompt,
+            checkpoint_path=DEFAULT_MODEL_PATH
+        )
+
+        print(f"Generated video path: {output_path}")
+
+        # Make sure an absolute path is returned
+        if output_path and not os.path.isabs(output_path):
+            output_path = os.path.abspath(output_path)
+
+        # Check that the file exists
+        if output_path and os.path.exists(output_path):
+            print(f"File exists, size: {os.path.getsize(output_path)} bytes")
+            return output_path
+        else:
+            print(f"Warning: output file does not exist or the path is invalid: {output_path}")
+            return None
+    except Exception as e:
+        import traceback
+        print(f"Apply tracking failed: {str(e)}\n{traceback.format_exc()}")
+        return None
+
+# Added after apply_tracking_unified and before the Gradio interface definition
+
+def enable_apply_button(tracking_result):
+    """Enable the apply button once the tracking video has been generated"""
+    if tracking_result is not None:
+        return gr.update(interactive=True)
+    return gr.update(interactive=False)
+
+@spaces.GPU
+def process_vggt(video_tensor):
+    vggt_model = get_vggt_model()
+
+    t, c, h, w = video_tensor.shape
+    new_width = 518
+    new_height = round(h * (new_width / w) / 14) * 14
+    resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+    video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
+
+    if new_height > 518:
+        start_y = (new_height - 518) // 2
+        video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.float16):
+            video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+            aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to("cuda"))
+
+    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+
+    return extr, intr
+
+def load_examples():
+    """Load example file paths"""
+    samples_dir = os.path.join(project_root, "samples")
+    if not os.path.exists(samples_dir):
+        print(f"Warning: Samples directory not found at {samples_dir}")
+        return []
+
+    examples_list = []
+
+    # Create one example item per sample set
+    # Example 1
+    example1 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample1_"):
+            if filename.endswith("_raw.mp4"):
+                example1[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example1[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example1[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example1[4] = os.path.join(samples_dir, filename)
+
+    # Set the prompt text for example 1
+    example1[2] = "a rocket lifts off from the table and smoke erupt from its bottom."
+
+    # Example 2
+    example2 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample2_"):
+            if filename.endswith("_raw.mp4"):
+                example2[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example2[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example2[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example2[4] = os.path.join(samples_dir, filename)
+
+    # Set the prompt text for example 2
+    example2[2] = "A wonderful bright old-fasion red car is riding from left to right sun light is shining on the car, its reflection glittering. In the background is a deserted city in the noon, the roads and buildings are covered with green vegetation."
+
+    # Add the examples to the list
+    if example1[0] is not None and example1[3] is not None:
+        examples_list.append(example1)
+
+    if example2[0] is not None and example2[3] is not None:
+        examples_list.append(example2)
+
+    # Add any other examples (if present)
+    sample_prefixes = set()
+    for filename in os.listdir(samples_dir):
+        if filename.endswith(('.mp4', '.png')):
+            prefix = filename.split('_')[0]
+            if prefix not in ["sample1", "sample2"]:
+                sample_prefixes.add(prefix)
+
+    for prefix in sorted(sample_prefixes):
+        example = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+        for filename in os.listdir(samples_dir):
+            if filename.startswith(f"{prefix}_"):
+                if filename.endswith("_raw.mp4"):
+                    example[0] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_repaint.png"):
+                    example[1] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_tracking.mp4"):
+                    example[3] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_result.mp4"):
+                    example[4] = os.path.join(samples_dir, filename)
+
+        # Add a default prompt text
+        example[2] = "A beautiful scene"
+
+        # Only add the example when at least the source file and tracking video exist
+        if example[0] is not None and example[3] is not None:
+            examples_list.append(example)
+
+    return examples_list
+
 # Create Gradio interface with updated layout
 with gr.Blocks(title="Diffusion as Shader") as demo:
     gr.Markdown("# Diffusion as Shader Web UI")
     gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")
 
+    # Create hidden state variables to store intermediate results
+    video_tensor_state = gr.State(None)
+    tracking_tensor_state = gr.State(None)
+    repaint_img_tensor_state = gr.State(None)
+    fps_state = gr.State(None)
+
     with gr.Row():
         left_column = gr.Column(scale=1)
         right_column = gr.Column(scale=1)
 
     with right_column:
-        output_video = gr.Video(label="Generated Video")
         tracking_video = gr.Video(label="Tracking Video")
+
+        # The button starts disabled
+        apply_tracking_btn = gr.Button("Generate Video", variant="primary", size="lg", interactive=False)
+        output_video = gr.Video(label="Generated Video")
 
     with left_column:
-
-
+        source_upload = gr.UploadButton("1. Upload Source", file_types=["image", "video"])
+        source_preview = gr.Video(label="Source Preview")
+        gr.Markdown("Upload a video or image; we will extract the motion and space structure from it")
+
+        # Update the preview after upload
+        def update_source_preview(file):
+            if file is None:
+                return None
+            path = save_uploaded_file(file)
+            return path
+
+        source_upload.upload(
+            fn=update_source_preview,
+            inputs=[source_upload],
+            outputs=[source_preview]
+        )
+
+        common_prompt = gr.Textbox(label="2. Prompt: Describe the scene and the motion you want to create", lines=2)
         gr.Markdown(f"**Using GPU: {GPU_ID}**")
 
         with gr.Tabs() as task_tabs:
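Note: `apply_tracking_unified` is decorated with `@spaces.GPU(duration=240)` and `process_vggt` with `@spaces.GPU`, the Hugging Face ZeroGPU pattern: the Space holds no GPU at rest and a device is attached only for the duration of the decorated call. A minimal sketch of the pattern, assuming the `spaces` package that ZeroGPU Spaces provide (the function is illustrative):

    import spaces
    import torch

    @spaces.GPU(duration=240)  # request a GPU slice for up to 240 s per call
    def heavy_step(x: torch.Tensor) -> torch.Tensor:
        # inside the call, CUDA is available; move data back to CPU before returning
        return (x.to("cuda") * 2).cpu()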
@@ -600,228 +708,308 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                     value="No"
                 )
                 gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
-
-                mt_repaint_image = gr.File(
-                    label="Custom Repaint Image",
-                    file_types=["image"]
+
+                mt_repaint_upload = gr.UploadButton("3. Upload Repaint Image (Optional)", file_types=["image"])
+                mt_repaint_preview = gr.Image(label="Repaint Image Preview")
+
+                # Update the preview after upload
+                mt_repaint_upload.upload(
+                    fn=update_source_preview,  # reuse the same function
+                    inputs=[mt_repaint_upload],
+                    outputs=[mt_repaint_preview]
                 )
 
                 # Add run button for Motion Transfer tab
-                mt_run_btn = gr.Button("
+                mt_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
-                # Connect to process function
+                # Connect to process function, but don't apply tracking
                 mt_run_btn.click(
                     fn=process_motion_transfer,
                     inputs=[
-
-                        mt_repaint_option,
+                        source_upload, common_prompt,
+                        mt_repaint_option, mt_repaint_upload
                     ],
-                    outputs=[tracking_video,
+                    outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+                ).then(
+                    fn=enable_apply_button,
+                    inputs=[tracking_video],
+                    outputs=[apply_tracking_btn]
                 )
 
-            # Camera Control tab
-            with gr.TabItem("Camera Control"):
-                [tab body not recoverable from the rendered diff]
-
-            # Object Manipulation tab
-            with gr.TabItem("Object Manipulation"):
-                [tab body not recoverable from the rendered diff]
-
-            # Animating meshes to video tab
-            with gr.TabItem("Animating meshes to video"):
-                [tab body not recoverable from the rendered diff]
+            # # Camera Control tab
+            # with gr.TabItem("Camera Control"):
+            #     gr.Markdown("## Camera Control")
+
+            #     cc_camera_motion = gr.Textbox(
+            #         label="Current Camera Motion Sequence",
+            #         placeholder="Your camera motion sequence will appear here...",
+            #         interactive=False
+            #     )
+
+            #     # Use tabs for different motion types
+            #     with gr.Tabs() as cc_motion_tabs:
+            #         # Translation tab
+            #         with gr.TabItem("Translation (trans)"):
+            #             with gr.Row():
+            #                 cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
+            #                 cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
+            #                 cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
+
+            #             with gr.Row():
+            #                 cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
+
+            #             cc_trans_note = gr.Markdown("""
+            #             **Translation Notes:**
+            #             - Positive X: Move right, Negative X: Move left
+            #             - Positive Y: Move down, Negative Y: Move up
+            #             - Positive Z: Zoom in, Negative Z: Zoom out
+            #             """)
+
+            #             # Add translation button in the Translation tab
+            #             cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
+
+            #             # Function to add translation motion
+            #             def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
+            #                 # Format: trans dx dy dz [start_frame end_frame]
+            #                 frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
+            #                 new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
+
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
+
+            #                 return updated_motion
+
+            #             # Connect translation button
+            #             cc_add_trans.click(
+            #                 fn=add_translation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
+
+            #         # Rotation tab
+            #         with gr.TabItem("Rotation (rot)"):
+            #             with gr.Row():
+            #                 cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
+            #                 cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
+
+            #             with gr.Row():
+            #                 cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
+
+            #             cc_rot_note = gr.Markdown("""
+            #             **Rotation Notes:**
+            #             - X-axis rotation: Tilt camera up/down
+            #             - Y-axis rotation: Pan camera left/right
+            #             - Z-axis rotation: Roll camera
+            #             """)
+
+            #             # Add rotation button in the Rotation tab
+            #             cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
+
+            #             # Function to add rotation motion
+            #             def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
+            #                 # Format: rot axis angle [start_frame end_frame]
+            #                 frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
+            #                 new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
+
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
+
+            #                 return updated_motion
+
+            #             # Connect rotation button
+            #             cc_add_rot.click(
+            #                 fn=add_rotation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
+
+            #     # Add a clear button to reset the motion sequence
+            #     cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
+
+            #     def clear_camera_motion():
+            #         return ""
+
+            #     cc_clear_motion.click(
+            #         fn=clear_camera_motion,
+            #         inputs=[],
+            #         outputs=[cc_camera_motion]
+            #     )
+
+            #     cc_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
+
+            #     # Add run button for Camera Control tab
+            #     cc_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
+
+            #     # Connect to process function, but don't apply tracking
+            #     cc_run_btn.click(
+            #         fn=process_camera_control,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             cc_camera_motion, cc_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
+
+            # # Object Manipulation tab
+            # with gr.TabItem("Object Manipulation"):
+            #     gr.Markdown("## Object Manipulation")
+            #     om_object_mask = gr.File(
+            #         label="Object Mask Image",
+            #         file_types=["image"]
+            #     )
+            #     gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
+            #     om_object_motion = gr.Dropdown(
+            #         label="Object Motion Type",
+            #         choices=["up", "down", "left", "right", "front", "back", "rot"],
+            #         value="up"
+            #     )
+            #     om_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
+
+            #     # Add run button for Object Manipulation tab
+            #     om_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
+
+            #     # Connect to process function, but don't apply tracking
+            #     om_run_btn.click(
+            #         fn=process_object_manipulation,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             om_object_motion, om_object_mask, om_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
+
+            # # Animating meshes to video tab
+            # with gr.TabItem("Animating meshes to video"):
+            #     gr.Markdown("## Mesh Animation to Video")
+            #     gr.Markdown("""
+            #     Note: Currently only supports tracking videos generated with Blender (version > 4.0).
+            #     Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
+            #     """)
+            #     ma_tracking_video = gr.File(
+            #         label="Tracking Video",
+            #         file_types=["video"],
+            #         # add a change handler so uploading a file auto-enables the Generate Video button
+            #         elem_id="ma_tracking_video"
+            #     )
+            #     gr.Markdown("Tracking video needs to be generated from Blender")
+
+            #     # Simplified controls - Radio buttons for Yes/No and separate file upload
+            #     with gr.Row():
+            #         ma_repaint_option = gr.Radio(
+            #             label="Repaint First Frame",
+            #             choices=["No", "Yes"],
+            #             value="No"
+            #         )
+            #     gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image in below.")
+            #     # Custom image uploader (always visible)
+            #     ma_repaint_image = gr.File(
+            #         label="Custom Repaint Image",
+            #         file_types=["image"]
+            #     )
+
+            #     # Rename the button to "Apply Repaint"
+            #     ma_run_btn = gr.Button("Apply Repaint", variant="primary", size="lg")
+
+            #     # Handle tracking video uploads
+            #     def handle_tracking_upload(file):
+            #         if file is not None:
+            #             tracking_path = save_uploaded_file(file)
+            #             if tracking_path:
+            #                 return tracking_path, gr.update(interactive=True)
+            #         return None, gr.update(interactive=False)
+
+            #     # When a tracking video is uploaded, show it and enable the Generate Video button
+            #     ma_tracking_video.change(
+            #         fn=handle_tracking_upload,
+            #         inputs=[ma_tracking_video],
+            #         outputs=[tracking_video, apply_tracking_btn]
+            #     )
+
+            #     # Adjust the behavior of process_mesh_animation
+            #     def process_mesh_animation_repaint(source, prompt, ma_repaint_option, ma_repaint_image):
+            #         """Only handle the repaint step, not the tracking video"""
+            #         try:
+            #             # Save the uploaded file
+            #             input_video_path = save_uploaded_file(source)
+            #             if input_video_path is None:
+            #                 return None, None, None, None
+
+            #             das = get_das_pipeline()
+            #             video_tensor, fps, is_video = load_media(input_video_path)
+            #             das.fps = fps
+
+            #             repaint_img_tensor = None
+            #             if ma_repaint_image is not None:
+            #                 repaint_path = save_uploaded_file(ma_repaint_image)
+            #                 repaint_img_tensor, _, _ = load_media(repaint_path)
+            #                 repaint_img_tensor = repaint_img_tensor[0]
+            #             elif ma_repaint_option == "Yes":
+            #                 repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
+            #                 repaint_img_tensor = repainter.repaint(
+            #                     video_tensor[0],
+            #                     prompt=prompt,
+            #                     depth_path=None
+            #                 )
+
+            #             # Return results, but without the tracking video path
+            #             return video_tensor, None, repaint_img_tensor, fps
+            #         except Exception as e:
+            #             import traceback
+            #             print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
+            #             return None, None, None, None
+
+            #     # Connect to the adjusted handler
+            #     ma_run_btn.click(
+            #         fn=process_mesh_animation_repaint,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             ma_repaint_option, ma_repaint_image
+            #         ],
+            #         outputs=[video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     )
+
+    # Add the Examples component after all UI elements are defined
+    examples_list = load_examples()
+    if examples_list:
+        with gr.Blocks() as examples_block:
+            gr.Examples(
+                examples=examples_list,
+                inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                fn=lambda *args: args,  # simply return the inputs as outputs
+                cache_examples=True,
+                label="Examples"
+            )
 
 # Launch interface
 if __name__ == "__main__":
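Note: each "Generate Tracking" click chains into `enable_apply_button` via `.then(...)`, so "Generate Video" only becomes interactive once a tracking video exists; the intermediate tensors ride along in `gr.State` components, which hold arbitrary Python objects between events without serializing them through the UI. A minimal sketch of both pieces together, assuming gradio >= 4 (component names are illustrative):

    import gradio as gr

    def stage_one():
        # stand-in for the tracking step: returns a video path and a tensor-like payload
        return "tracking.mp4", {"tracks": [1, 2, 3]}

    def enable_if_ready(result):
        return gr.update(interactive=result is not None)

    def stage_two(payload):
        return f"consumed {payload}"

    with gr.Blocks() as demo:
        payload_state = gr.State(None)          # hidden, survives between events
        video = gr.Video()
        out = gr.Textbox()
        apply_btn = gr.Button("Generate Video", interactive=False)
        run_btn = gr.Button("Generate Tracking")

        run_btn.click(fn=stage_one, outputs=[video, payload_state]).then(
            fn=enable_if_ready, inputs=[video], outputs=[apply_btn]
        )
        apply_btn.click(fn=stage_two, inputs=[payload_state], outputs=[out])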
@@ -831,4 +1019,4 @@ if __name__ == "__main__":
     print("Creating public link for remote access")
 
     # Launch interface
-    demo.launch(share=args.share, server_port=args.port)
+    demo.launch(share=args.share, server_port=args.port)