EXCAI committed on
Commit 9db1e11 · 1 Parent(s): 79dc1ad
Files changed (2)
  1. .gitignore +2 -1
  2. app.py +469 -281
.gitignore CHANGED
@@ -197,4 +197,5 @@ slurm-*.out
 .vscode
 
 data/
-tmp/
+tmp/
+.gradio/
app.py CHANGED
@@ -33,6 +33,9 @@ from submodules.MoGe.moge.model import MoGeModel
 from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from submodules.vggt.vggt.models.vggt import VGGT
 
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 # Parse command line arguments
 parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
 parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
@@ -82,9 +85,9 @@ def load_media(media_path, max_frames=49, transform=None):
 
     # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
     if duration > 6.0:
-        sampling_fps = 8  # 8 frames per second
-        frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
-        fps = sampling_fps
+        # Use the max_frames parameter instead of sampling_fps
+        frames = load_video(media_path, max_frames=max_frames)
+        fps = max_frames / 6.0  # compute the equivalent fps
     # Cases 2 and 3: Video shorter than 6 seconds
     else:
         # Load all frames
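With this change, a clip longer than 6 seconds is sampled by frame budget rather than at a fixed rate. Assuming the `max_frames=49` default visible in the hunk header, the equivalent fps works out close to the hard-coded 8 it replaces:

```python
# Quick arithmetic check (illustrative only; max_frames=49 is load_media's default).
max_frames = 49
fps = max_frames / 6.0
print(round(fps, 3))  # 8.167 -- close to the old fixed sampling_fps = 8
```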
@@ -195,10 +198,10 @@ def get_vggt_model():
 def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
     """Process video motion transfer task"""
     try:
-        # Save uploaded files
+        # Save the uploaded file
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         print(f"DEBUG: Repaint option: {mt_repaint_option}")
         print(f"DEBUG: Repaint image: {mt_repaint_image}")
@@ -253,28 +256,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
         tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # Return the intermediate results without applying tracking
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_camera_control(source, prompt, camera_motion, tracking_method):
     """Process camera control task"""
     try:
-        # Save uploaded files
+        # Save the uploaded file
         input_media_path = save_uploaded_file(source)
         if input_media_path is None:
-            return None, None
+            return None, None, None, None, None
 
         print(f"DEBUG: Camera motion: '{camera_motion}'")
         print(f"DEBUG: Tracking method: '{tracking_method}'")
@@ -317,24 +312,8 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         # Use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # Use the wrapped VGGT helper
+        extr, intr = process_vggt(video_tensor)
 
         cam_motion.set_intr(intr)
         cam_motion.set_extr(extr)
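The inline block removed here (re-added later in this diff as `process_vggt`) resizes to a fixed 518-pixel width, snaps the height to a multiple of 14, and center-crops when the height still exceeds 518. A standalone check of that arithmetic with made-up input sizes (the multiple-of-14 constraint presumably matches the model's patch size, which is an assumption):

```python
# Illustrative check of the VGGT preprocessing dimensions (not app code).
def vggt_dims(h: int, w: int, new_width: int = 518) -> tuple[int, int]:
    # Scale the height to keep the aspect ratio, then snap to a multiple of 14.
    new_height = round(h * (new_width / w) / 14) * 14
    return new_height, new_width

print(vggt_dims(480, 720))  # (350, 518): landscape input, fits without cropping
print(vggt_dims(720, 480))  # (784, 518): portrait input, center-cropped to 518 rows
```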
@@ -345,23 +324,15 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
         print("Camera motion applied")
 
-        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, None)
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
         print('Export tracking video via cotracker')
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_path, output_path
+        # Return the intermediate results without applying tracking
+        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
     """Process object manipulation task"""
@@ -369,12 +340,12 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
         # Save uploaded files
         input_image_path = save_uploaded_file(source)
         if input_image_path is None:
-            return None, None
+            return None, None, None, None, None
 
         object_mask_path = save_uploaded_file(object_mask)
         if object_mask_path is None:
             print("Object mask not provided")
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_image_path)
@@ -424,24 +395,8 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
         # Use the cotracker that runs on the CPU
         pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
 
-        t, c, h, w = video_tensor.shape
-        new_width = 518
-        new_height = round(h * (new_width / w) / 14) * 14
-        resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
-        video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
-
-        if new_height > 518:
-            start_y = (new_height - 518) // 2
-            video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
-
-        vggt_model = get_vggt_model()
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=das.dtype):
-                video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
-                aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
-
-        extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+        # Use the wrapped VGGT helper
+        extr, intr = process_vggt(video_tensor)
 
         pred_tracks = motion_generator.apply_motion(
             pred_tracks=pred_tracks.squeeze(),
453
  )
454
  print(f"Object motion '{object_motion}' applied using provided mask")
455
 
456
- tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), None)
457
  print('Export tracking video via cotracker')
458
 
459
- output_path = das.apply_tracking(
460
- video_tensor=video_tensor,
461
- fps=fps, # 使用 load_media 返回的 fps
462
- tracking_tensor=tracking_tensor,
463
- img_cond_tensor=repaint_img_tensor,
464
- prompt=prompt,
465
- checkpoint_path=DEFAULT_MODEL_PATH
466
- )
467
-
468
- return tracking_path, output_path
469
  except Exception as e:
470
  import traceback
471
  print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
472
- return None, None
473
 
474
  def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
475
  """Process mesh animation task"""
@@ -477,11 +424,11 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         tracking_video_path = save_uploaded_file(tracking_video)
         if tracking_video_path is None:
-            return None, None
+            return None, None, None, None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
@@ -494,7 +441,6 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
             repaint_img_tensor, _, _ = load_media(repaint_path)
             repaint_img_tensor = repaint_img_tensor[0]  # take the first frame
         elif ma_repaint_option == "Yes":
-
             repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
             repaint_img_tensor = repainter.repaint(
                 video_tensor[0],
@@ -502,20 +448,12 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
                 depth_path=None
             )
 
-        output_path = das.apply_tracking(
-            video_tensor=video_tensor,
-            fps=fps,  # use the fps returned by load_media
-            tracking_tensor=tracking_tensor,
-            img_cond_tensor=repaint_img_tensor,
-            prompt=prompt,
-            checkpoint_path=DEFAULT_MODEL_PATH
-        )
-
-        return tracking_video_path, output_path
+        # Return the uploaded tracking video path directly instead of generating a new one
+        return tracking_video_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None
+        return None, None, None, None, None
 
 def generate_tracking_cotracker(video_tensor, density=30):
     """Generate the tracking video on the CPU, using only the first frame's depth information and matrix operations for efficiency
@@ -569,22 +507,192 @@ def generate_tracking_cotracker(video_tensor, density=30):
     # Return the results
     return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
 
+@spaces.GPU(duration=240)
+def apply_tracking_unified(video_tensor, tracking_tensor, repaint_img_tensor, prompt, fps):
+    """Unified function for applying tracking"""
+    try:
+        if video_tensor is None or tracking_tensor is None:
+            return None
+
+        das = get_das_pipeline()
+        output_path = das.apply_tracking(
+            video_tensor=video_tensor,
+            fps=fps,
+            tracking_tensor=tracking_tensor,
+            img_cond_tensor=repaint_img_tensor,
+            prompt=prompt,
+            checkpoint_path=DEFAULT_MODEL_PATH
+        )
+
+        print(f"Generated video path: {output_path}")
+
+        # Make sure the returned path is absolute
+        if output_path and not os.path.isabs(output_path):
+            output_path = os.path.abspath(output_path)
+
+        # Check whether the file exists
+        if output_path and os.path.exists(output_path):
+            print(f"File exists, size: {os.path.getsize(output_path)} bytes")
+            return output_path
+        else:
+            print(f"Warning: output file does not exist or the path is invalid: {output_path}")
+            return None
+    except Exception as e:
+        import traceback
+        print(f"Apply tracking failed: {str(e)}\n{traceback.format_exc()}")
+        return None
+
+# Added after apply_tracking_unified, before the Gradio interface definition
+
+def enable_apply_button(tracking_result):
+    """Enable the apply button once the tracking video has been generated"""
+    if tracking_result is not None:
+        return gr.update(interactive=True)
+    return gr.update(interactive=False)
+
+@spaces.GPU
+def process_vggt(video_tensor):
+    vggt_model = get_vggt_model()
+
+    t, c, h, w = video_tensor.shape
+    new_width = 518
+    new_height = round(h * (new_width / w) / 14) * 14
+    resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+    video_vggt = resize_transform(video_tensor)  # [T, C, H, W]
+
+    if new_height > 518:
+        start_y = (new_height - 518) // 2
+        video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+    with torch.no_grad():
+        with torch.cuda.amp.autocast(dtype=torch.float16):
+            video_vggt = video_vggt.unsqueeze(0)  # [1, T, C, H, W]
+            aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to("cuda"))
+
+    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+
+    return extr, intr
+
+def load_examples():
+    """Load example file paths"""
+    samples_dir = os.path.join(project_root, "samples")
+    if not os.path.exists(samples_dir):
+        print(f"Warning: Samples directory not found at {samples_dir}")
+        return []
+
+    examples_list = []
+
+    # Create one example entry per sample set
+    # Example 1
+    example1 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample1_"):
+            if filename.endswith("_raw.mp4"):
+                example1[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example1[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example1[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example1[4] = os.path.join(samples_dir, filename)
+
+    # Set the prompt text for example 1
+    example1[2] = "a rocket lifts off from the table and smoke erupt from its bottom."
+
+    # Example 2
+    example2 = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+    for filename in os.listdir(samples_dir):
+        if filename.startswith("sample2_"):
+            if filename.endswith("_raw.mp4"):
+                example2[0] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_repaint.png"):
+                example2[1] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_tracking.mp4"):
+                example2[3] = os.path.join(samples_dir, filename)
+            elif filename.endswith("_result.mp4"):
+                example2[4] = os.path.join(samples_dir, filename)
+
+    # Set the prompt text for example 2
+    example2[2] = "A wonderful bright old-fasion red car is riding from left to right sun light is shining on the car, its reflection glittering. In the background is a deserted city in the noon, the roads and buildings are covered with green vegetation."
+
+    # Add the examples to the list
+    if example1[0] is not None and example1[3] is not None:
+        examples_list.append(example1)
+
+    if example2[0] is not None and example2[3] is not None:
+        examples_list.append(example2)
+
+    # Add any other examples (if present)
+    sample_prefixes = set()
+    for filename in os.listdir(samples_dir):
+        if filename.endswith(('.mp4', '.png')):
+            prefix = filename.split('_')[0]
+            if prefix not in ["sample1", "sample2"]:
+                sample_prefixes.add(prefix)
+
+    for prefix in sorted(sample_prefixes):
+        example = [None] * 5  # [source, repaint_image, prompt, tracking_video, result_video]
+        for filename in os.listdir(samples_dir):
+            if filename.startswith(f"{prefix}_"):
+                if filename.endswith("_raw.mp4"):
+                    example[0] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_repaint.png"):
+                    example[1] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_tracking.mp4"):
+                    example[3] = os.path.join(samples_dir, filename)
+                elif filename.endswith("_result.mp4"):
+                    example[4] = os.path.join(samples_dir, filename)
+
+        # Default prompt text
+        example[2] = "A beautiful scene"
+
+        # Only add the example when it has at least a source file and a tracking video
+        if example[0] is not None and example[3] is not None:
+            examples_list.append(example)
+
+    return examples_list
+
 # Create Gradio interface with updated layout
 with gr.Blocks(title="Diffusion as Shader") as demo:
     gr.Markdown("# Diffusion as Shader Web UI")
     gr.Markdown("### [Project Page](https://igl-hkust.github.io/das/) | [GitHub](https://github.com/IGL-HKUST/DiffusionAsShader)")
 
+    # Create hidden state variables to store intermediate results
+    video_tensor_state = gr.State(None)
+    tracking_tensor_state = gr.State(None)
+    repaint_img_tensor_state = gr.State(None)
+    fps_state = gr.State(None)
+
     with gr.Row():
         left_column = gr.Column(scale=1)
         right_column = gr.Column(scale=1)
 
     with right_column:
-        output_video = gr.Video(label="Generated Video")
         tracking_video = gr.Video(label="Tracking Video")
+
+        # The button is disabled in the initial state
+        apply_tracking_btn = gr.Button("Generate Video", variant="primary", size="lg", interactive=False)
+        output_video = gr.Video(label="Generated Video")
 
     with left_column:
-        source = gr.File(label="Source", file_types=["image", "video"])
-        common_prompt = gr.Textbox(label="Prompt", lines=2)
+        source_upload = gr.UploadButton("1. Upload Source", file_types=["image", "video"])
+        source_preview = gr.Video(label="Source Preview")
+        gr.Markdown("Upload a video or image; we will extract the motion and spatial structure from it")
+
+        # Update the preview after a file is uploaded
+        def update_source_preview(file):
+            if file is None:
+                return None
+            path = save_uploaded_file(file)
+            return path
+
+        source_upload.upload(
+            fn=update_source_preview,
+            inputs=[source_upload],
+            outputs=[source_preview]
+        )
+
+        common_prompt = gr.Textbox(label="2. Prompt: Describe the scene and the motion you want to create", lines=2)
         gr.Markdown(f"**Using GPU: {GPU_ID}**")
 
         with gr.Tabs() as task_tabs:
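The hunk above defines `apply_tracking_btn`, the `gr.State` holders, and `apply_tracking_unified`, but the click wiring that connects them does not appear in this diff. A hedged sketch of how that hookup would plausibly look, reusing only names introduced above (the repository's actual wiring may differ):

```python
# Hypothetical wiring (not shown in this commit): the initially disabled
# "Generate Video" button consumes the cached tensors from gr.State and runs
# the GPU-heavy apply_tracking_unified only on demand.
apply_tracking_btn.click(
    fn=apply_tracking_unified,
    inputs=[video_tensor_state, tracking_tensor_state,
            repaint_img_tensor_state, common_prompt, fps_state],
    outputs=[output_video]
)
```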
@@ -600,228 +708,308 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                     value="No"
                 )
                 gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image below.")
-                # Custom image uploader (always visible)
-                mt_repaint_image = gr.File(
-                    label="Custom Repaint Image",
-                    file_types=["image"]
+
+                mt_repaint_upload = gr.UploadButton("3. Upload Repaint Image (Optional)", file_types=["image"])
+                mt_repaint_preview = gr.Image(label="Repaint Image Preview")
+
+                # Update the preview after a file is uploaded
+                mt_repaint_upload.upload(
+                    fn=update_source_preview,  # reuse the same function
+                    inputs=[mt_repaint_upload],
+                    outputs=[mt_repaint_preview]
                 )
 
                 # Add run button for Motion Transfer tab
-                mt_run_btn = gr.Button("Run Motion Transfer", variant="primary", size="lg")
+                mt_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
-                # Connect to process function
+                # Connect to process function, but don't apply tracking
                 mt_run_btn.click(
                     fn=process_motion_transfer,
                     inputs=[
-                        source, common_prompt,
-                        mt_repaint_option, mt_repaint_image
+                        source_upload, common_prompt,
+                        mt_repaint_option, mt_repaint_upload
                     ],
-                    outputs=[tracking_video, output_video]
+                    outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+                ).then(
+                    fn=enable_apply_button,
+                    inputs=[tracking_video],
+                    outputs=[apply_tracking_btn]
                 )
 
-            # Camera Control tab
-            with gr.TabItem("Camera Control"):
-                gr.Markdown("## Camera Control")
+            # # Camera Control tab
+            # with gr.TabItem("Camera Control"):
+            #     gr.Markdown("## Camera Control")
 
-                cc_camera_motion = gr.Textbox(
-                    label="Current Camera Motion Sequence",
-                    placeholder="Your camera motion sequence will appear here...",
-                    interactive=False
-                )
+            #     cc_camera_motion = gr.Textbox(
+            #         label="Current Camera Motion Sequence",
+            #         placeholder="Your camera motion sequence will appear here...",
+            #         interactive=False
+            #     )
 
-                # Use tabs for different motion types
-                with gr.Tabs() as cc_motion_tabs:
-                    # Translation tab
-                    with gr.TabItem("Translation (trans)"):
-                        with gr.Row():
-                            cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
-                            cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
-                            cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
+            #     # Use tabs for different motion types
+            #     with gr.Tabs() as cc_motion_tabs:
+            #         # Translation tab
+            #         with gr.TabItem("Translation (trans)"):
+            #             with gr.Row():
+            #                 cc_trans_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="X-axis Movement")
+            #                 cc_trans_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Y-axis Movement")
+            #                 cc_trans_z = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, step=0.05, label="Z-axis Movement (depth)")
 
-                        with gr.Row():
-                            cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
-                            cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
+            #             with gr.Row():
+            #                 cc_trans_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_trans_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
 
-                        cc_trans_note = gr.Markdown("""
-                        **Translation Notes:**
-                        - Positive X: Move right, Negative X: Move left
-                        - Positive Y: Move down, Negative Y: Move up
-                        - Positive Z: Zoom in, Negative Z: Zoom out
-                        """)
+            #             cc_trans_note = gr.Markdown("""
+            #             **Translation Notes:**
+            #             - Positive X: Move right, Negative X: Move left
+            #             - Positive Y: Move down, Negative Y: Move up
+            #             - Positive Z: Zoom in, Negative Z: Zoom out
+            #             """)
 
-                        # Add translation button in the Translation tab
-                        cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
+            #             # Add translation button in the Translation tab
+            #             cc_add_trans = gr.Button("Add Camera Translation", variant="secondary")
 
-                        # Function to add translation motion
-                        def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
-                            # Format: trans dx dy dz [start_frame end_frame]
-                            frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
-                            new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
+            #             # Function to add translation motion
+            #             def add_translation_motion(current_motion, trans_x, trans_y, trans_z, trans_start, trans_end):
+            #                 # Format: trans dx dy dz [start_frame end_frame]
+            #                 frame_range = f" {int(trans_start)} {int(trans_end)}" if trans_start != 0 or trans_end != 48 else ""
+            #                 new_motion = f"trans {trans_x:.2f} {trans_y:.2f} {trans_z:.2f}{frame_range}"
 
-                            # Append to existing motion string with semicolon separator if needed
-                            if current_motion and current_motion.strip():
-                                updated_motion = f"{current_motion}; {new_motion}"
-                            else:
-                                updated_motion = new_motion
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
 
-                            return updated_motion
+            #                 return updated_motion
 
-                        # Connect translation button
-                        cc_add_trans.click(
-                            fn=add_translation_motion,
-                            inputs=[
-                                cc_camera_motion,
-                                cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
-                            ],
-                            outputs=[cc_camera_motion]
-                        )
+            #             # Connect translation button
+            #             cc_add_trans.click(
+            #                 fn=add_translation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_trans_x, cc_trans_y, cc_trans_z, cc_trans_start, cc_trans_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
 
-                    # Rotation tab
-                    with gr.TabItem("Rotation (rot)"):
-                        with gr.Row():
-                            cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
-                            cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
+            #         # Rotation tab
+            #         with gr.TabItem("Rotation (rot)"):
+            #             with gr.Row():
+            #                 cc_rot_axis = gr.Dropdown(choices=["x", "y", "z"], value="y", label="Rotation Axis")
+            #                 cc_rot_angle = gr.Slider(minimum=-30, maximum=30, value=5, step=1, label="Rotation Angle (degrees)")
 
-                        with gr.Row():
-                            cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
-                            cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
+            #             with gr.Row():
+            #                 cc_rot_start = gr.Number(minimum=0, maximum=48, value=0, step=1, label="Start Frame", precision=0)
+            #                 cc_rot_end = gr.Number(minimum=0, maximum=48, value=48, step=1, label="End Frame", precision=0)
 
-                        cc_rot_note = gr.Markdown("""
-                        **Rotation Notes:**
-                        - X-axis rotation: Tilt camera up/down
-                        - Y-axis rotation: Pan camera left/right
-                        - Z-axis rotation: Roll camera
-                        """)
+            #             cc_rot_note = gr.Markdown("""
+            #             **Rotation Notes:**
+            #             - X-axis rotation: Tilt camera up/down
+            #             - Y-axis rotation: Pan camera left/right
+            #             - Z-axis rotation: Roll camera
+            #             """)
 
-                        # Add rotation button in the Rotation tab
-                        cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
+            #             # Add rotation button in the Rotation tab
+            #             cc_add_rot = gr.Button("Add Camera Rotation", variant="secondary")
 
-                        # Function to add rotation motion
-                        def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
-                            # Format: rot axis angle [start_frame end_frame]
-                            frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
-                            new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
+            #             # Function to add rotation motion
+            #             def add_rotation_motion(current_motion, rot_axis, rot_angle, rot_start, rot_end):
+            #                 # Format: rot axis angle [start_frame end_frame]
+            #                 frame_range = f" {int(rot_start)} {int(rot_end)}" if rot_start != 0 or rot_end != 48 else ""
+            #                 new_motion = f"rot {rot_axis} {rot_angle}{frame_range}"
 
-                            # Append to existing motion string with semicolon separator if needed
-                            if current_motion and current_motion.strip():
-                                updated_motion = f"{current_motion}; {new_motion}"
-                            else:
-                                updated_motion = new_motion
+            #                 # Append to existing motion string with semicolon separator if needed
+            #                 if current_motion and current_motion.strip():
+            #                     updated_motion = f"{current_motion}; {new_motion}"
+            #                 else:
+            #                     updated_motion = new_motion
 
-                            return updated_motion
+            #                 return updated_motion
 
-                        # Connect rotation button
-                        cc_add_rot.click(
-                            fn=add_rotation_motion,
-                            inputs=[
-                                cc_camera_motion,
-                                cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
-                            ],
-                            outputs=[cc_camera_motion]
-                        )
+            #             # Connect rotation button
+            #             cc_add_rot.click(
+            #                 fn=add_rotation_motion,
+            #                 inputs=[
+            #                     cc_camera_motion,
+            #                     cc_rot_axis, cc_rot_angle, cc_rot_start, cc_rot_end
+            #                 ],
+            #                 outputs=[cc_camera_motion]
+            #             )
 
-                # Add a clear button to reset the motion sequence
-                cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
+            #     # Add a clear button to reset the motion sequence
+            #     cc_clear_motion = gr.Button("Clear All Motions", variant="stop")
 
-                def clear_camera_motion():
-                    return ""
+            #     def clear_camera_motion():
+            #         return ""
 
-                cc_clear_motion.click(
-                    fn=clear_camera_motion,
-                    inputs=[],
-                    outputs=[cc_camera_motion]
-                )
-
-                cc_tracking_method = gr.Radio(
-                    label="Tracking Method",
-                    choices=["moge", "cotracker"],
-                    value="cotracker"
-                )
+            #     cc_clear_motion.click(
+            #         fn=clear_camera_motion,
+            #         inputs=[],
+            #         outputs=[cc_camera_motion]
+            #     )
+            #
+            #     cc_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
 
-                # Add run button for Camera Control tab
-                cc_run_btn = gr.Button("Run Camera Control", variant="primary", size="lg")
+            #     # Add run button for Camera Control tab
+            #     cc_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
-                # Connect to process function
-                cc_run_btn.click(
-                    fn=process_camera_control,
-                    inputs=[
-                        source, common_prompt,
-                        cc_camera_motion, cc_tracking_method
-                    ],
-                    outputs=[tracking_video, output_video]
-                )
+            #     # Connect to process function, but don't apply tracking
+            #     cc_run_btn.click(
+            #         fn=process_camera_control,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             cc_camera_motion, cc_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
 
-            # Object Manipulation tab
-            with gr.TabItem("Object Manipulation"):
-                gr.Markdown("## Object Manipulation")
-                om_object_mask = gr.File(
-                    label="Object Mask Image",
-                    file_types=["image"]
-                )
-                gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
-                om_object_motion = gr.Dropdown(
-                    label="Object Motion Type",
-                    choices=["up", "down", "left", "right", "front", "back", "rot"],
-                    value="up"
-                )
-                om_tracking_method = gr.Radio(
-                    label="Tracking Method",
-                    choices=["moge", "cotracker"],
-                    value="cotracker"
-                )
+            # # Object Manipulation tab
+            # with gr.TabItem("Object Manipulation"):
+            #     gr.Markdown("## Object Manipulation")
+            #     om_object_mask = gr.File(
+            #         label="Object Mask Image",
+            #         file_types=["image"]
+            #     )
+            #     gr.Markdown("Upload a binary mask image, white areas indicate the object to manipulate")
+            #     om_object_motion = gr.Dropdown(
+            #         label="Object Motion Type",
+            #         choices=["up", "down", "left", "right", "front", "back", "rot"],
+            #         value="up"
+            #     )
+            #     om_tracking_method = gr.Radio(
+            #         label="Tracking Method",
+            #         choices=["moge", "cotracker"],
+            #         value="cotracker"
+            #     )
 
-                # Add run button for Object Manipulation tab
-                om_run_btn = gr.Button("Run Object Manipulation", variant="primary", size="lg")
+            #     # Add run button for Object Manipulation tab
+            #     om_run_btn = gr.Button("Generate Tracking", variant="primary", size="lg")
 
-                # Connect to process function
-                om_run_btn.click(
-                    fn=process_object_manipulation,
-                    inputs=[
-                        source, common_prompt,
-                        om_object_motion, om_object_mask, om_tracking_method
-                    ],
-                    outputs=[tracking_video, output_video]
-                )
+            #     # Connect to process function, but don't apply tracking
+            #     om_run_btn.click(
+            #         fn=process_object_manipulation,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             om_object_motion, om_object_mask, om_tracking_method
+            #         ],
+            #         outputs=[tracking_video, video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     ).then(
+            #         fn=enable_apply_button,
+            #         inputs=[tracking_video],
+            #         outputs=[apply_tracking_btn]
+            #     )
 
-            # Animating meshes to video tab
-            with gr.TabItem("Animating meshes to video"):
-                gr.Markdown("## Mesh Animation to Video")
-                gr.Markdown("""
-                Note: Currently only supports tracking videos generated with Blender (version > 4.0).
-                Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
-                """)
-                ma_tracking_video = gr.File(
-                    label="Tracking Video",
-                    file_types=["video"]
-                )
-                gr.Markdown("Tracking video needs to be generated from Blender")
+            # # Animating meshes to video tab
+            # with gr.TabItem("Animating meshes to video"):
+            #     gr.Markdown("## Mesh Animation to Video")
+            #     gr.Markdown("""
+            #     Note: Currently only supports tracking videos generated with Blender (version > 4.0).
+            #     Please run the script `scripts/blender.py` in your Blender project to generate tracking videos.
+            #     """)
+            #     ma_tracking_video = gr.File(
+            #         label="Tracking Video",
+            #         file_types=["video"],
+            #         # Add a change handler so uploading a file auto-enables the Generate Video button
+            #         elem_id="ma_tracking_video"
+            #     )
+            #     gr.Markdown("Tracking video needs to be generated from Blender")
 
-                # Simplified controls - Radio buttons for Yes/No and separate file upload
-                with gr.Row():
-                    ma_repaint_option = gr.Radio(
-                        label="Repaint First Frame",
-                        choices=["No", "Yes"],
-                        value="No"
-                    )
-                gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image below.")
-                # Custom image uploader (always visible)
-                ma_repaint_image = gr.File(
-                    label="Custom Repaint Image",
-                    file_types=["image"]
-                )
+            #     # Simplified controls - Radio buttons for Yes/No and separate file upload
+            #     with gr.Row():
+            #         ma_repaint_option = gr.Radio(
+            #             label="Repaint First Frame",
+            #             choices=["No", "Yes"],
+            #             value="No"
+            #         )
+            #     gr.Markdown("### Note: If you want to use your own image as repainted first frame, please upload the image below.")
+            #     # Custom image uploader (always visible)
+            #     ma_repaint_image = gr.File(
+            #         label="Custom Repaint Image",
+            #         file_types=["image"]
+            #     )
 
-                # Add run button for Mesh Animation tab
-                ma_run_btn = gr.Button("Run Mesh Animation", variant="primary", size="lg")
+            #     # Rename the button to "Apply Repaint"
+            #     ma_run_btn = gr.Button("Apply Repaint", variant="primary", size="lg")
 
-                # Connect to process function
-                ma_run_btn.click(
-                    fn=process_mesh_animation,
-                    inputs=[
-                        source, common_prompt,
-                        ma_tracking_video, ma_repaint_option, ma_repaint_image
-                    ],
-                    outputs=[tracking_video, output_video]
-                )
+            #     # Handle the tracking-video upload event
+            #     def handle_tracking_upload(file):
+            #         if file is not None:
+            #             tracking_path = save_uploaded_file(file)
+            #             if tracking_path:
+            #                 return tracking_path, gr.update(interactive=True)
+            #         return None, gr.update(interactive=False)
+            #
+            #     # When a tracking video is uploaded, show it directly and enable the Generate Video button
+            #     ma_tracking_video.change(
+            #         fn=handle_tracking_upload,
+            #         inputs=[ma_tracking_video],
+            #         outputs=[tracking_video, apply_tracking_btn]
+            #     )
+            #
+            #     # Modify the behaviour of process_mesh_animation
+            #     def process_mesh_animation_repaint(source, prompt, ma_repaint_option, ma_repaint_image):
+            #         """Handle only the repaint step, not the tracking video"""
+            #         try:
+            #             # Save the uploaded file
+            #             input_video_path = save_uploaded_file(source)
+            #             if input_video_path is None:
+            #                 return None, None, None, None
+            #
+            #             das = get_das_pipeline()
+            #             video_tensor, fps, is_video = load_media(input_video_path)
+            #             das.fps = fps
+            #
+            #             repaint_img_tensor = None
+            #             if ma_repaint_image is not None:
+            #                 repaint_path = save_uploaded_file(ma_repaint_image)
+            #                 repaint_img_tensor, _, _ = load_media(repaint_path)
+            #                 repaint_img_tensor = repaint_img_tensor[0]
+            #             elif ma_repaint_option == "Yes":
+            #                 repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
+            #                 repaint_img_tensor = repainter.repaint(
+            #                     video_tensor[0],
+            #                     prompt=prompt,
+            #                     depth_path=None
+            #                 )
+            #
+            #             # Return the results, without a tracking video path
+            #             return video_tensor, None, repaint_img_tensor, fps
+            #         except Exception as e:
+            #             import traceback
+            #             print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
+            #             return None, None, None, None
+            #
+            #     # Connect to the modified process function
+            #     ma_run_btn.click(
+            #         fn=process_mesh_animation_repaint,
+            #         inputs=[
+            #             source_upload, common_prompt,
+            #             ma_repaint_option, ma_repaint_image
+            #         ],
+            #         outputs=[video_tensor_state, tracking_tensor_state, repaint_img_tensor_state, fps_state]
+            #     )
+
+    # Add the Examples component after all UI elements are defined
+    examples_list = load_examples()
+    if examples_list:
+        with gr.Blocks() as examples_block:
+            gr.Examples(
+                examples=examples_list,
+                inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                fn=lambda *args: args,  # simply return the inputs as outputs
+                cache_examples=True,
+                label="Examples"
+            )
 
 # Launch interface
 if __name__ == "__main__":
@@ -831,4 +1019,4 @@ if __name__ == "__main__":
         print("Creating public link for remote access")
 
     # Launch interface
-    demo.launch(share=args.share, server_port=args.port)
+    demo.launch(share=args.share, server_port=args.port)
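Taken as a whole, the commit replaces the one-shot "Run ..." buttons with a two-stage flow: a cheap tracking step that stashes heavy intermediates in `gr.State`, then a deferred, GPU-decorated generation step behind an initially disabled button. A minimal self-contained sketch of that pattern (generic names; not code from this repository):

```python
import gradio as gr

def stage_one(text):
    # Cheap step: produce a preview and stash an intermediate result.
    intermediate = text.upper()
    return f"preview of {intermediate}", intermediate

def stage_two(intermediate):
    # Expensive step: runs only when explicitly requested.
    return f"final result built from {intermediate}"

with gr.Blocks() as demo:
    state = gr.State(None)  # hidden holder for intermediate results
    inp = gr.Textbox(label="Input")
    preview = gr.Textbox(label="Preview")
    run_btn = gr.Button("Run stage 1")
    final_btn = gr.Button("Run stage 2", interactive=False)
    out = gr.Textbox(label="Output")

    # Stage 1 fills the preview and the state, then enables the second button.
    run_btn.click(stage_one, [inp], [preview, state]).then(
        lambda p: gr.update(interactive=p is not None), [preview], [final_btn]
    )
    # Stage 2 consumes only the cached state.
    final_btn.click(stage_two, [state], [out])

if __name__ == "__main__":
    demo.launch()
```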