# ECON / apps/benchmark.py - last commit e0ba903 by Yuliang: "upgrade to Gradio 4.14.0" (11.5 kB)
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
"""Benchmark ECON reconstruction on the evaluation split.

For every sample of ``EvalDataset`` the ``__main__`` section of this script
runs the normal network (``normal_F`` / ``normal_B``), extracts a surface with
BNI, optionally completes it with IF-Net+, fuses the parts into a final mesh
and scores it with ``Evaluator`` (Chamfer, P2S and normal consistency).
Meshes and a ``metric.npy`` summary are written under
``<cfg.results_path>/<cfg.name>/{IF-Net+|SMPL-X}``.
"""
import logging
import warnings
# Silence Python warnings and the "lightning" / "trimesh" loggers. Placed
# before the imports below, presumably so that warnings emitted while those
# modules are imported are suppressed as well.
warnings.filterwarnings("ignore")
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("trimesh").setLevel(logging.ERROR)
import argparse
import os
import torch
from termcolor import colored
from tqdm.auto import tqdm
from apps.IFGeo import IFGeo
from apps.Normal import Normal
from lib.common.BNI import BNI
from lib.common.BNI_utils import save_normal_tensor
from lib.common.config import cfg
from lib.common.voxelize import VoxelGrid
from lib.dataset.EvalDataset import EvalDataset
from lib.dataset.Evaluator import Evaluator
# NOTE(review): wildcard import. This script uses several names that are not
# imported explicitly anywhere in it (e.g. osp, np, trimesh, SMPLX, Format,
# apply_face_mask, apply_vertex_mask, part_removal, poisson, clean_mesh,
# remesh_laplacian) -- presumably all re-exported by mesh_util; confirm before
# replacing this with explicit imports.
from lib.dataset.mesh_util import *
# Let cuDNN benchmark and cache the fastest convolution algorithms. This only
# pays off when input shapes stay constant across iterations (presumably true
# for this evaluation loop -- confirm).
torch.backends.cudnn.benchmark = True
# Set to True to profile the whole run with cProfile; the stats are dumped to
# <export_dir>/econ.stats and the top 10 entries sorted by cumulative time are
# printed at the end of the run.
speed_analysis = False
if __name__ == "__main__":
    # Script entry point.
    #  1. Parse the CLI and build the frozen benchmark config.
    #  2. Load the normal estimator (and IF-Net+ when -ifnet is given).
    #  3. For every evaluation sample, reconstruct a final mesh (or reuse an
    #     existing "<name>_final.obj") and score it against the ground truth.
    #  4. Print per-dataset averages of chamfer / p2s / nc.
    # Scores are cached in <export_dir>/metric.npy, and samples already listed
    # there are not re-evaluated.

    if speed_analysis:
        import cProfile
        import pstats

        profiler = cProfile.Profile()
        profiler.enable()

    # ------------------------------------------------------------ CLI / cfg
    parser = argparse.ArgumentParser()
    parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
    parser.add_argument("-ifnet", action="store_true")
    parser.add_argument("-cfg", "--config", type=str, default="./configs/econ.yaml")
    args = parser.parse_args()

    # cfg read and merge
    cfg.merge_from_file(args.config)
    # Honour -gpu/--gpu_device. The flag used to be parsed but ignored in
    # favour of a hard-coded "cuda:0"; the default of 0 keeps that behaviour.
    device = torch.device(f"cuda:{args.gpu_device}")

    # Benchmark-specific overrides on top of the YAML config.
    cfg_test_list = [
        "dataset.rotation_num",
        3,
        "bni.use_smpl",
        ["hand"],
        "bni.use_ifnet",
        args.ifnet,
        "bni.cut_intersection",
        True,
    ]

    # # if w/ RenderPeople+CAPE
    # cfg_test_list += ["dataset.types", ["cape", "renderpeople"], "dataset.scales", [100.0, 1.0]]

    # if only w/ CAPE
    cfg_test_list += ["dataset.types", ["cape"], "dataset.scales", [100.0]]

    cfg.merge_from_list(cfg_test_list)
    cfg.freeze()

    # --------------------------------------------------------------- models
    # load normal model
    normal_net = Normal.load_from_checkpoint(
        cfg=cfg, checkpoint_path=cfg.normal_path, map_location=device, strict=False
    )
    normal_net = normal_net.to(device)
    normal_net.netG.eval()
    print(colored(f"Resume Normal Estimator from: {cfg.normal_path}", "green"))

    # SMPLX object (supplies the eyeball / MANO masks used below)
    SMPLX_object = SMPLX()

    dataset = EvalDataset(cfg=cfg, device=device)
    evaluator = Evaluator(device=device)
    export_dir = osp.join(cfg.results_path, cfg.name, "IF-Net+" if cfg.bni.use_ifnet else "SMPL-X")
    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    if cfg.bni.use_ifnet:
        # load IFGeo model
        ifnet = IFGeo.load_from_checkpoint(
            cfg=cfg, checkpoint_path=cfg.ifnet_path, map_location=device, strict=False
        )
        ifnet = ifnet.to(device)
        ifnet.netG.eval()
        print(colored(f"Resume IF-Net+ from {Format.start} {cfg.ifnet_path} {Format.end}", "green"))
        print(colored(f"Complete with {Format.start} IF-Nets+ (Implicit) {Format.end}", "green"))
    else:
        print(colored(f"Complete with {Format.start} SMPL-X (Explicit) {Format.end}", "green"))

    # Create the export root up front: metric.npy and the profiler dump are
    # written directly into it, even when no sample directory gets created.
    os.makedirs(export_dir, exist_ok=True)

    # Scores persisted by a previous run. Within this run the file is only
    # written from the in-memory dict below, so loading it once here is
    # equivalent to the former per-iteration reload.
    # NOTE: allow_pickle=True unpickles the file -- only point results_path at
    # directories you trust.
    metric_path = osp.join(export_dir, "metric.npy")
    if osp.exists(metric_path):
        benchmark = np.load(metric_path, allow_pickle=True).item()
    else:
        benchmark = {}

    # Multiplier that negates the second vertex coordinate before rendering
    # (hoisted out of the loop; it never changes).
    flip_y = torch.tensor([1.0, -1.0, 1.0]).to(device)

    pbar = tqdm(dataset)
    for data in pbar:

        # Add a batch dimension to every tensor and move it to the device.
        for key, value in data.items():
            if torch.is_tensor(value):
                data[key] = value.unsqueeze(0).to(device)

        is_smplx = "smplx_path" in data

        # filenames and makedirs
        dataset_name = data["dataset"]
        subject_key = f"{data['subject']}-{data['rotation']}"
        current_name = f"{dataset_name}-{data['subject']}-{data['rotation']:03d}"
        current_dir = osp.join(export_dir, dataset_name, data["subject"])
        os.makedirs(current_dir, exist_ok=True)
        final_path = osp.join(current_dir, f"{current_name}_final.obj")

        if not osp.exists(final_path):
            # ------------------------------------------------ reconstruction
            in_tensor = data.copy()

            # NOTE(review): .detach() returns a tensor sharing storage with its
            # source, so this in-place flip also modifies in_tensor["smpl_verts"]
            # and data["smpl_verts"] (data.copy() is shallow). Kept unchanged
            # because normal_net.netG(in_tensor) below may rely on it -- confirm.
            batch_smpl_verts = in_tensor["smpl_verts"].detach()
            batch_smpl_verts *= flip_y
            batch_smpl_faces = in_tensor["smpl_faces"].detach()

            in_tensor["depth_F"], in_tensor["depth_B"] = dataset.render_depth(
                batch_smpl_verts, batch_smpl_faces
            )

            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = normal_net.netG(in_tensor)

            smpl_mesh = trimesh.Trimesh(
                batch_smpl_verts.cpu().numpy()[0],
                batch_smpl_faces.cpu().numpy()[0],
            )

            # Separate copies of the body mesh for the side / hand / reference
            # roles below, so trimming one never touches the others.
            side_mesh = smpl_mesh.copy()
            hand_mesh = smpl_mesh.copy()
            smplx_mesh = smpl_mesh.copy()

            # save normals, depths and masks
            BNI_dict = save_normal_tensor(
                in_tensor,
                0,
                osp.join(current_dir, "BNI/param_dict"),
                cfg.bni.thickness if dataset_name == "renderpeople" else 0.0,
            )

            # BNI process
            BNI_object = BNI(
                dir_path=osp.join(current_dir, "BNI"),
                name=current_name,
                BNI_dict=BNI_dict,
                cfg=cfg.bni,
                device=device,
            )
            BNI_object.extract_surface(False)

            if is_smplx:
                side_mesh = apply_face_mask(side_mesh, ~SMPLX_object.smplx_eyeball_fid_mask)

            if cfg.bni.use_ifnet:
                # mesh completion via IF-net
                in_tensor.update(
                    dataset.depth_to_voxel({
                        "depth_F": BNI_object.F_depth.unsqueeze(0).to(device),
                        "depth_B": BNI_object.B_depth.unsqueeze(0).to(device),
                    })
                )

                occupancies = VoxelGrid.from_mesh(
                    side_mesh, cfg.vol_res, loc=[0] * 3, scale=2.0
                ).data.transpose(2, 1, 0)
                occupancies = np.flip(occupancies, axis=1)
                # np.flip returns a view; .copy() materialises it into a fresh
                # array before handing it to torch.
                in_tensor["body_voxels"] = (
                    torch.tensor(occupancies.copy()).float().unsqueeze(0).to(device)
                )

                with torch.no_grad():
                    sdf = ifnet.reconEngine(netG=ifnet.netG, batch=in_tensor)
                    verts_IF, faces_IF = ifnet.reconEngine.export_mesh(sdf)

                if ifnet.clean_mesh_flag:
                    verts_IF, faces_IF = clean_mesh(verts_IF, faces_IF)

                side_mesh_path = osp.join(current_dir, f"{current_name}_IF.obj")
                side_mesh = remesh_laplacian(trimesh.Trimesh(verts_IF, faces_IF), side_mesh_path)

            full_lst = []

            if "hand" in cfg.bni.use_smpl:
                # only hands: keep the MANO vertices of the body model
                mano_mask = (
                    SMPLX_object.smplx_mano_vertex_mask
                    if is_smplx else SMPLX_object.smpl_mano_vertex_mask
                )
                hand_mesh = apply_vertex_mask(hand_mesh, mano_mask)

                # remove hand neighbor triangles
                BNI_object.F_B_trimesh = part_removal(
                    BNI_object.F_B_trimesh,
                    hand_mesh,
                    cfg.bni.hand_thres,
                    device,
                    smplx_mesh,
                    region="hand",
                )
                side_mesh = part_removal(
                    side_mesh, hand_mesh, cfg.bni.hand_thres, device, smplx_mesh, region="hand"
                )
                full_lst.append(hand_mesh)

            full_lst.append(BNI_object.F_B_trimesh)

            # initial side_mesh could be SMPLX or IF-net
            side_mesh = part_removal(
                side_mesh, sum(full_lst), 2e-2, device, smplx_mesh, region="", clean=False
            )
            full_lst.append(side_mesh)

            if cfg.bni.use_poisson:
                final_mesh = poisson(sum(full_lst), final_path, cfg.bni.poisson_depth)
            else:
                final_mesh = sum(full_lst)
                final_mesh.export(final_path)
        else:
            # Reuse the mesh reconstructed by a previous run.
            final_mesh = trimesh.load(final_path)

        # ------------------------------------------------------ evaluation
        scores = benchmark.get(dataset_name)
        if scores is None or subject_key not in scores["subject"]:
            result_eval = {
                "verts_gt": data["verts"][0],
                "faces_gt": data["faces"][0],
                "verts_pr": final_mesh.vertices,
                "faces_pr": final_mesh.faces,
                "calib": data["calib"][0],
            }
            evaluator.set_mesh(result_eval, scale=False)
            chamfer, p2s = evaluator.calculate_chamfer_p2s(num_samples=1000)
            nc = evaluator.calculate_normal_consist(osp.join(current_dir, f"{current_name}_nc.png"))

            scores = benchmark.setdefault(
                dataset_name,
                {"chamfer": [], "p2s": [], "nc": [], "subject": [], "total": 0},
            )
            scores["chamfer"].append(chamfer.item())
            scores["p2s"].append(p2s.item())
            scores["nc"].append(nc.item())
            scores["subject"].append(subject_key)
            scores["total"] += 1

            # Persist after every sample so finished work survives an interruption.
            np.save(metric_path, benchmark, allow_pickle=True)
        else:
            # Already scored: reuse the cached numbers for the progress bar.
            subject_idx = scores["subject"].index(subject_key)
            chamfer = torch.tensor(scores["chamfer"][subject_idx])
            p2s = torch.tensor(scores["p2s"][subject_idx])
            nc = torch.tensor(scores["nc"][subject_idx])

        pbar.set_description(
            f"{current_name} | {chamfer.item():.3f} | {p2s.item():.3f} | {nc.item():.4f}"
        )

    # ------------------------------------------------------------- summary
    # Loop variable renamed so it no longer shadows the EvalDataset `dataset`.
    for ds_name, ds_scores in benchmark.items():
        for metric in ["chamfer", "p2s", "nc"]:
            print(f"{ds_name}-{metric}: {sum(ds_scores[metric]) / ds_scores['total']:.4f}")

    if cfg.bni.use_ifnet:
        print(colored("Finish evaluating on ECON_IF", "green"))
    else:
        print(colored("Finish evaluating on ECON_EX", "green"))

    if speed_analysis:
        profiler.disable()
        stats_path = osp.join(export_dir, "econ.stats")
        profiler.dump_stats(stats_path)
        stats = pstats.Stats(stats_path)
        stats.sort_stats("cumtime").print_stats(10)