EXCAI committed on
Commit bb57964 · verified · 1 Parent(s): 23fa0c8

Update app.py

Files changed (1)
  1. app.py +14 -206
app.py CHANGED
@@ -262,199 +262,7 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
         return None, None, None, None, None
-
-def process_camera_control(source, prompt, camera_motion, tracking_method):
-    """Process camera control task"""
-    try:
-        # Save the uploaded file
-        input_media_path = save_uploaded_file(source)
-        if input_media_path is None:
-            return None, None, None
-
-        print(f"DEBUG: Camera motion: '{camera_motion}'")
-        print(f"DEBUG: Tracking method: '{tracking_method}'")
-
-        das = get_das_pipeline()
-        video_tensor, fps, is_video = load_media(input_media_path)
-        das.fps = fps  # Set das.fps to the fps returned by load_media
-
-        if not is_video:
-            tracking_method = "moge"
-            print("Image input detected, switching to MoGe")
-
-        cam_motion = CameraMotionGenerator(camera_motion)
-        repaint_img_tensor = None
-        tracking_tensor = None
-
-        if tracking_method == "moge":
-            moge = get_moge_model()
-
-            infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
-            H, W = infer_result["points"].shape[0:2]
-            pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
-            cam_motion.set_intr(infer_result["intrinsics"])
-
-            if camera_motion:
-                poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
-                print("Camera motion applied")
-            else:
-                poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
-
-            pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
-            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
-
-            _, tracking_tensor = das.visualize_tracking_moge(
-                pred_tracks.cpu().numpy(),
-                infer_result["mask"].cpu().numpy()
-            )
-            print('Export tracking video via MoGe')
-        else:
-            # Use the CoTracker that runs on the CPU
-            pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
-
-            # Use the wrapped VGGT processing function
-            extr, intr = process_vggt(video_tensor)
-
-            cam_motion.set_intr(intr)
-            cam_motion.set_extr(extr)
-
-            if camera_motion:
-                poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
-                pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
-                pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses)  # [T, N, 3]
-                print("Camera motion applied")
-
-            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
-            print('Export tracking video via cotracker')
-
-        # Return the processing results without applying tracking
-        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
-    except Exception as e:
-        import traceback
-        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None, None, None, None
-
-def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
-    """Process object manipulation task"""
-    try:
-        # Save uploaded files
-        input_image_path = save_uploaded_file(source)
-        if input_image_path is None:
-            return None, None, None, None, None
-
-        object_mask_path = save_uploaded_file(object_mask)
-        if object_mask_path is None:
-            print("Object mask not provided")
-            return None, None, None, None, None
 
-        das = get_das_pipeline()
-        video_tensor, fps, is_video = load_media(input_image_path)
-        das.fps = fps  # Set das.fps to the fps returned by load_media
-
-        if not is_video:
-            tracking_method = "moge"
-            print("Image input detected, switching to MoGe")
-
-        mask_image = Image.open(object_mask_path).convert('L')
-        mask_image = transforms.Resize((480, 720))(mask_image)
-        mask = torch.from_numpy(np.array(mask_image) > 127)
-
-        motion_generator = ObjectMotionGenerator(device=das.device)
-        repaint_img_tensor = None
-        tracking_tensor = None
-
-        if tracking_method == "moge":
-            moge = get_moge_model()
-
-            infer_result = moge.infer(video_tensor[0].to(das.device))  # [C, H, W] in range [0,1]
-            H, W = infer_result["points"].shape[0:2]
-            pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1)  # [T, H, W, 3]
-
-            pred_tracks = motion_generator.apply_motion(
-                pred_tracks=pred_tracks,
-                mask=mask,
-                motion_type=object_motion,
-                distance=50,
-                num_frames=49,
-                tracking_method="moge"
-            )
-            print(f"Object motion '{object_motion}' applied using provided mask")
-            poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
-            pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
-
-            cam_motion = CameraMotionGenerator(None)
-            cam_motion.set_intr(infer_result["intrinsics"])
-            pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3])  # [T, H, W, 3]
-
-            _, tracking_tensor = das.visualize_tracking_moge(
-                pred_tracks.cpu().numpy(),
-                infer_result["mask"].cpu().numpy()
-            )
-            print('Export tracking video via MoGe')
-        else:
-            # Use the CoTracker that runs on the CPU
-            pred_tracks, pred_visibility = generate_tracking_cotracker(video_tensor)
-
-            # Use the wrapped VGGT processing function
-            extr, intr = process_vggt(video_tensor)
-
-            pred_tracks = motion_generator.apply_motion(
-                pred_tracks=pred_tracks.squeeze(),
-                mask=mask,
-                motion_type=object_motion,
-                distance=50,
-                num_frames=49,
-                tracking_method="cotracker"
-            )
-            print(f"Object motion '{object_motion}' applied using provided mask")
-
-            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), pred_visibility)
-            print('Export tracking video via cotracker')
-
-        # Return the processing results without applying tracking
-        return tracking_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
-    except Exception as e:
-        import traceback
-        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None, None, None, None
-
-def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
-    """Process mesh animation task"""
-    try:
-        # Save uploaded files
-        input_video_path = save_uploaded_file(source)
-        if input_video_path is None:
-            return None, None, None, None, None
-
-        tracking_video_path = save_uploaded_file(tracking_video)
-        if tracking_video_path is None:
-            return None, None, None, None, None
-
-        das = get_das_pipeline()
-        video_tensor, fps, is_video = load_media(input_video_path)
-        das.fps = fps  # Set das.fps to the fps returned by load_media
-
-        tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
-        repaint_img_tensor = None
-        if ma_repaint_image is not None:
-            repaint_path = save_uploaded_file(ma_repaint_image)
-            repaint_img_tensor, _, _ = load_media(repaint_path)
-            repaint_img_tensor = repaint_img_tensor[0]  # take the first frame
-        elif ma_repaint_option == "Yes":
-            repainter = FirstFrameRepainter(gpu_id=GPU_ID, output_dir=OUTPUT_DIR)
-            repaint_img_tensor = repainter.repaint(
-                video_tensor[0],
-                prompt=prompt,
-                depth_path=None
-            )
-
-        # Return the uploaded tracking video path directly instead of generating a new tracking video
-        return tracking_video_path, video_tensor, tracking_tensor, repaint_img_tensor, fps
-    except Exception as e:
-        import traceback
-        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None, None, None, None, None
-
 def generate_tracking_cotracker(video_tensor, density=30):
     """Generate the tracking video on the CPU, using only the first frame's depth information; matrix operations are used for efficiency
 
@@ -674,18 +482,6 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
             apply_tracking_btn = gr.Button("Generate Video", variant="primary", size="lg", interactive=False)
             output_video = gr.Video(label="Generated Video")
 
-            examples_list = load_examples()
-            if examples_list:
-                with gr.Blocks() as examples_block:
-                    gr.Examples(
-                        examples=examples_list,
-                        inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
-                        outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
-                        fn=lambda *args: args,
-                        cache_examples=True,
-                        label="Examples"
-                    )
-
         with left_column:
             source_upload = gr.UploadButton("1. Upload Source", file_types=["image", "video"])
             gr.Markdown("Upload a video or image, We will extract the motion and space structure from it")
@@ -749,10 +545,22 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                     )
 
                 with gr.TabItem("Camera Control"):
-                    gr.Markdown("Camera Control is not available in Huggingface Space, please deploy our GitHub project on your own machine")
+                    gr.Markdown("Camera Control is not available in Huggingface Space, please deploy our [GitHub project](https://github.com/IGL-HKUST/DiffusionAsShader) on your own machine")
 
                 with gr.TabItem("Object Manipulation"):
-                    gr.Markdown("Object Manipulation is not available in Huggingface Space, please deploy our GitHub project on your own machine")
+                    gr.Markdown("Object Manipulation is not available in Huggingface Space, please deploy our [GitHub project](https://github.com/IGL-HKUST/DiffusionAsShader) on your own machine")
+
+    examples_list = load_examples()
+    if examples_list:
+        with gr.Blocks() as examples_block:
+            gr.Examples(
+                examples=examples_list,
+                inputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                outputs=[source_preview, mt_repaint_preview, common_prompt, tracking_video, output_video],
+                fn=lambda *args: args,
+                cache_examples=True,
+                label="Examples"
+            )
 
 
 # Launch interface