Yuliang commited on
Commit
e5f16e8
·
1 Parent(s): fed242e

query colors from RGB image

Browse files
Files changed (1) hide show
  1. apps/infer.py +33 -46
apps/infer.py CHANGED
@@ -33,6 +33,7 @@ from apps.Normal import Normal
33
  from apps.IFGeo import IFGeo
34
  from pytorch3d.ops import SubdivideMeshes
35
  from lib.common.config import cfg
 
36
  from lib.common.train_util import init_loss, load_normal_networks, load_networks
37
  from lib.common.BNI import BNI
38
  from lib.common.BNI_utils import save_normal_tensor
@@ -93,14 +94,14 @@ if __name__ == "__main__":
93
  "vol_res": cfg.vol_res,
94
  "single": args.multi,
95
  }
96
-
97
  if cfg.bni.use_ifnet:
98
  print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
99
  else:
100
  print(colored("Use SMPL-X (Explicit) for completion", "green"))
101
 
102
  dataset = TestDataset(dataset_param, device)
103
-
104
  print(colored(f"Dataset Size: {len(dataset)}", "green"))
105
 
106
  pbar = tqdm(dataset)
@@ -130,11 +131,7 @@ if __name__ == "__main__":
130
 
131
  os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
132
 
133
- in_tensor = {
134
- "smpl_faces": data["smpl_faces"],
135
- "image": data["img_icon"].to(device),
136
- "mask": data["img_mask"].to(device)
137
- }
138
 
139
  # The optimizer and variables
140
  optimed_pose = data["body_pose"].requires_grad_(True)
@@ -158,7 +155,7 @@ if __name__ == "__main__":
158
  N_body, N_pose = optimed_pose.shape[:2]
159
 
160
  smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
161
-
162
  if osp.exists(smpl_path):
163
 
164
  smpl_verts_lst = []
@@ -183,7 +180,7 @@ if __name__ == "__main__":
183
 
184
  in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
185
  in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
186
-
187
  else:
188
  # smpl optimization
189
  loop_smpl = tqdm(range(args.loop_smpl))
@@ -252,16 +249,14 @@ if __name__ == "__main__":
252
 
253
  # BUG: PyTorch3D silhouette renderer generates dilated mask
254
  bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
255
- smpl_arr_fake = torch.cat(
256
- [in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
257
- dim=-1)
258
 
259
  body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
260
  body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
261
  body_overlap_flag = body_overlap < cfg.body_overlap_thres
262
 
263
- losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] +
264
- diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
265
 
266
  losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
267
  occluded_idx = torch.where(body_overlap_flag)[0]
@@ -308,18 +303,15 @@ if __name__ == "__main__":
308
 
309
  img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
310
  torchvision.utils.save_image(
311
- torch.cat([
312
- data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5,
313
- (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5
314
- ],
315
- dim=3), img_crop_path)
316
 
317
  rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
318
  rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
319
 
320
  img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
321
- torchvision.utils.save_image(
322
- torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
323
 
324
  smpl_obj_lst = []
325
 
@@ -397,12 +389,7 @@ if __name__ == "__main__":
397
  )
398
 
399
  # BNI process
400
- BNI_object = BNI(
401
- dir_path=osp.join(args.out_dir, cfg.name, "BNI"),
402
- name=data["name"],
403
- BNI_dict=BNI_dict,
404
- cfg=cfg.bni,
405
- device=device)
406
 
407
  BNI_object.extract_surface(False)
408
 
@@ -419,16 +406,11 @@ if __name__ == "__main__":
419
  side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
420
 
421
  # mesh completion via IF-net
422
- in_tensor.update(
423
- dataset.depth_to_voxel({
424
- "depth_F": BNI_object.F_depth.unsqueeze(0),
425
- "depth_B": BNI_object.B_depth.unsqueeze(0)
426
- }))
427
-
428
- occupancies = VoxelGrid.from_mesh(
429
- side_mesh, cfg.vol_res, loc=[
430
- 0,
431
- ] * 3, scale=2.0).data.transpose(2, 1, 0)
432
  occupancies = np.flip(occupancies, axis=1)
433
 
434
  in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
@@ -446,10 +428,9 @@ if __name__ == "__main__":
446
  else:
447
  side_mesh = apply_vertex_mask(
448
  side_mesh,
449
- (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
450
- SMPLX_object.eyeball_vertex_mask).eq(0).float(),
451
  )
452
-
453
  #register side_mesh to BNI surfaces
454
  side_mesh = Meshes(
455
  verts=[torch.tensor(side_mesh.vertices).float()],
@@ -458,7 +439,6 @@ if __name__ == "__main__":
458
  sm = SubdivideMeshes(side_mesh)
459
  side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
460
 
461
-
462
  side_verts = torch.tensor(side_mesh.vertices).float().to(device)
463
  side_faces = torch.tensor(side_mesh.faces).long().to(device)
464
 
@@ -469,7 +449,6 @@ if __name__ == "__main__":
469
 
470
  # export intermediate meshes
471
  BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
472
-
473
  full_lst = []
474
 
475
  if "face" in cfg.bni.use_smpl:
@@ -479,8 +458,7 @@ if __name__ == "__main__":
479
  face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
480
 
481
  # remove face neighbor triangles
482
- BNI_object.F_B_trimesh = part_removal(
483
- BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
484
  side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
485
  face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
486
  full_lst += [face_mesh]
@@ -497,8 +475,7 @@ if __name__ == "__main__":
497
  hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
498
 
499
  # remove hand neighbor triangles
500
- BNI_object.F_B_trimesh = part_removal(
501
- BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
502
  side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
503
  hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
504
  full_lst += [hand_mesh]
@@ -528,6 +505,16 @@ if __name__ == "__main__":
528
  rotate_recon_lst = dataset.render.get_image(cam_type="four")
529
  per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
530
 
 
 
 
 
 
 
 
 
 
 
531
  # for video rendering
532
  in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
533
  in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())
 
33
  from apps.IFGeo import IFGeo
34
  from pytorch3d.ops import SubdivideMeshes
35
  from lib.common.config import cfg
36
+ from lib.common.render import query_color
37
  from lib.common.train_util import init_loss, load_normal_networks, load_networks
38
  from lib.common.BNI import BNI
39
  from lib.common.BNI_utils import save_normal_tensor
 
94
  "vol_res": cfg.vol_res,
95
  "single": args.multi,
96
  }
97
+
98
  if cfg.bni.use_ifnet:
99
  print(colored("Use IF-Nets (Implicit)+ for completion", "green"))
100
  else:
101
  print(colored("Use SMPL-X (Explicit) for completion", "green"))
102
 
103
  dataset = TestDataset(dataset_param, device)
104
+
105
  print(colored(f"Dataset Size: {len(dataset)}", "green"))
106
 
107
  pbar = tqdm(dataset)
 
131
 
132
  os.makedirs(osp.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
133
 
134
+ in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["img_icon"].to(device), "mask": data["img_mask"].to(device)}
 
 
 
 
135
 
136
  # The optimizer and variables
137
  optimed_pose = data["body_pose"].requires_grad_(True)
 
155
  N_body, N_pose = optimed_pose.shape[:2]
156
 
157
  smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
158
+
159
  if osp.exists(smpl_path):
160
 
161
  smpl_verts_lst = []
 
180
 
181
  in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
182
  in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
183
+
184
  else:
185
  # smpl optimization
186
  loop_smpl = tqdm(range(args.loop_smpl))
 
249
 
250
  # BUG: PyTorch3D silhouette renderer generates dilated mask
251
  bg_value = in_tensor["T_normal_F"][0, 0, 0, 0]
252
+ smpl_arr_fake = torch.cat([in_tensor["T_normal_F"][:, 0].ne(bg_value).float(), in_tensor["T_normal_B"][:, 0].ne(bg_value).float()],
253
+ dim=-1)
 
254
 
255
  body_overlap = (gt_arr * smpl_arr_fake.gt(0.0)).sum(dim=[1, 2]) / smpl_arr_fake.gt(0.0).sum(dim=[1, 2])
256
  body_overlap_mask = (gt_arr * smpl_arr_fake).unsqueeze(1)
257
  body_overlap_flag = body_overlap < cfg.body_overlap_thres
258
 
259
+ losses["normal"]["value"] = (diff_F_smpl * body_overlap_mask[..., :512] + diff_B_smpl * body_overlap_mask[..., 512:]).mean() / 2.0
 
260
 
261
  losses["silhouette"]["weight"] = [0 if flag else 1.0 for flag in body_overlap_flag]
262
  occluded_idx = torch.where(body_overlap_flag)[0]
 
303
 
304
  img_crop_path = osp.join(args.out_dir, cfg.name, "png", f"{data['name']}_crop.png")
305
  torchvision.utils.save_image(
306
+ torch.cat(
307
+ [data["img_crop"][:, :3], (in_tensor['normal_F'].detach().cpu() + 1.0) * 0.5, (in_tensor['normal_B'].detach().cpu() + 1.0) * 0.5],
308
+ dim=3), img_crop_path)
 
 
309
 
310
  rgb_norm_F = blend_rgb_norm(in_tensor["normal_F"], data)
311
  rgb_norm_B = blend_rgb_norm(in_tensor["normal_B"], data)
312
 
313
  img_overlap_path = osp.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png")
314
+ torchvision.utils.save_image(torch.Tensor([data["img_raw"], rgb_norm_F, rgb_norm_B]).permute(0, 3, 1, 2) / 255., img_overlap_path)
 
315
 
316
  smpl_obj_lst = []
317
 
 
389
  )
390
 
391
  # BNI process
392
+ BNI_object = BNI(dir_path=osp.join(args.out_dir, cfg.name, "BNI"), name=data["name"], BNI_dict=BNI_dict, cfg=cfg.bni, device=device)
 
 
 
 
 
393
 
394
  BNI_object.extract_surface(False)
395
 
 
406
  side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)
407
 
408
  # mesh completion via IF-net
409
+ in_tensor.update(dataset.depth_to_voxel({"depth_F": BNI_object.F_depth.unsqueeze(0), "depth_B": BNI_object.B_depth.unsqueeze(0)}))
410
+
411
+ occupancies = VoxelGrid.from_mesh(side_mesh, cfg.vol_res, loc=[
412
+ 0,
413
+ ] * 3, scale=2.0).data.transpose(2, 1, 0)
 
 
 
 
 
414
  occupancies = np.flip(occupancies, axis=1)
415
 
416
  in_tensor["body_voxels"] = torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
 
428
  else:
429
  side_mesh = apply_vertex_mask(
430
  side_mesh,
431
+ (SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask + SMPLX_object.eyeball_vertex_mask).eq(0).float(),
 
432
  )
433
+
434
  #register side_mesh to BNI surfaces
435
  side_mesh = Meshes(
436
  verts=[torch.tensor(side_mesh.vertices).float()],
 
439
  sm = SubdivideMeshes(side_mesh)
440
  side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
441
 
 
442
  side_verts = torch.tensor(side_mesh.vertices).float().to(device)
443
  side_faces = torch.tensor(side_mesh.faces).long().to(device)
444
 
 
449
 
450
  # export intermediate meshes
451
  BNI_object.F_B_trimesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_BNI.obj")
 
452
  full_lst = []
453
 
454
  if "face" in cfg.bni.use_smpl:
 
458
  face_mesh.vertices = face_mesh.vertices - np.array([0, 0, cfg.bni.thickness])
459
 
460
  # remove face neighbor triangles
461
+ BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
 
462
  side_mesh = part_removal(side_mesh, face_mesh, cfg.bni.face_thres, device, smplx_mesh, region="face")
463
  face_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_face.obj")
464
  full_lst += [face_mesh]
 
475
  hand_mesh = apply_vertex_mask(hand_mesh, hand_mask)
476
 
477
  # remove hand neighbor triangles
478
+ BNI_object.F_B_trimesh = part_removal(BNI_object.F_B_trimesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
 
479
  side_mesh = part_removal(side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand")
480
  hand_mesh.export(f"{args.out_dir}/{cfg.name}/obj/{data['name']}_{idx}_hand.obj")
481
  full_lst += [hand_mesh]
 
505
  rotate_recon_lst = dataset.render.get_image(cam_type="four")
506
  per_loop_lst.extend([in_tensor['image'][idx:idx + 1]] + rotate_recon_lst)
507
 
508
+ # coloring the final mesh
509
+ final_colors = query_color(
510
+ torch.tensor(final_mesh.vertices).float(),
511
+ torch.tensor(final_mesh.faces).long(),
512
+ in_tensor["image"][idx:idx + 1],
513
+ device=device,
514
+ )
515
+ final_mesh.visual.vertex_colors = final_colors
516
+ final_mesh.export(final_path)
517
+
518
  # for video rendering
519
  in_tensor["BNI_verts"].append(torch.tensor(final_mesh.vertices).float())
520
  in_tensor["BNI_faces"].append(torch.tensor(final_mesh.faces).long())