Spaces:
Running
on
L40S
Running
on
L40S
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: ps-license@tuebingen.mpg.de | |
from lib.dataset.mesh_util import projection | |
from lib.common.render import Render | |
import numpy as np | |
import torch | |
import os.path as osp | |
from torchvision.utils import make_grid | |
from pytorch3d.io import IO | |
from pytorch3d.ops import sample_points_from_meshes | |
from pytorch3d.loss.point_mesh_distance import _PointFaceDistance | |
from pytorch3d.structures import Pointclouds | |
from PIL import Image | |
def point_mesh_distance(meshes, pcls):
    """Average point-to-surface distance between paired meshes and point clouds.

    Each point of cloud ``i`` is measured against the triangles of mesh ``i``.
    The distances are averaged within each cloud (every point is weighted by
    1 / size of its cloud), and those per-cloud means are then averaged over
    the batch.

    Args:
        meshes: PyTorch3D ``Meshes`` batch.
        pcls: PyTorch3D ``Pointclouds`` batch with the same batch size.

    Returns:
        Scalar tensor with the batch-mean point-to-surface distance.

    Raises:
        ValueError: if ``meshes`` and ``pcls`` have different batch sizes.
    """
    if len(pcls) != len(meshes):
        raise ValueError("meshes and pointclouds must be equal sized batches")
    batch_size = len(meshes)

    # Packed point data: all points concatenated, plus per-cloud start offsets.
    pts = pcls.points_packed()  # (P, 3)
    pts_start = pcls.cloud_to_packed_first_idx()
    pts_per_cloud = pcls.num_points_per_cloud()  # (N,)
    longest_cloud = pts_per_cloud.max().item()

    # Packed triangle data: vertex coordinates gathered per face -> (T, 3, 3).
    triangles = meshes.verts_packed()[meshes.faces_packed()]
    tri_start = meshes.mesh_to_faces_packed_first_idx()

    # Per-point distance to the nearest triangle of its mesh; the result is
    # square-rooted below, i.e. this op yields squared distances. The trailing
    # 5e-3 is presumably the op's minimum-triangle-area argument (TODO confirm
    # against the installed pytorch3d version).
    sq_dist = _PointFaceDistance.apply(pts, pts_start, triangles, tri_start,
                                       longest_cloud, 5e-3)

    # Weight every point by the inverse size of the cloud it belongs to.
    owner_sizes = pts_per_cloud.gather(0, pcls.packed_to_cloud_idx())
    inv_sizes = 1.0 / owner_sizes.float()

    weighted = torch.sqrt(sq_dist) * inv_sizes
    return weighted.sum() / batch_size
class Evaluator:
    """Compare a reconstructed (predicted) mesh with a ground-truth mesh.

    Call ``set_mesh`` first; it builds ``self.src_mesh`` (prediction) and
    ``self.tgt_mesh`` (ground truth). The other methods then compute
    rendered-normal consistency, Chamfer / point-to-surface distances, or
    occupancy classification metrics.
    """

    def __init__(self, device):
        # Renderer used by calculate_normal_consist and set_mesh (VF2Mesh).
        self.render = Render(size=512, device=device)
        self.device = device

    def set_mesh(self, result_dict):
        """Load predicted and ground-truth geometry from ``result_dict``.

        Every key of ``result_dict`` becomes an attribute of ``self``. This
        method reads ``verts_pr``, ``faces_pr``, ``verts_gt``, ``faces_gt``,
        ``calib`` and ``recon_size``.

        The predicted vertices are mapped affinely so that ``0`` -> ``-1`` and
        ``recon_size`` -> ``1``. The ground-truth vertices are passed through
        ``projection(verts_gt, calib)`` and then have column 1 (y) negated.
        """
        for key, value in result_dict.items():
            setattr(self, key, value)

        # Rebind rather than using in-place ``-=`` / ``/=``: the attribute
        # aliases the caller's array, so in-place ops would silently rewrite
        # result_dict (and re-normalize it on a second call).
        half_size = self.recon_size / 2.0
        self.verts_pr = (self.verts_pr - half_size) / half_size

        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr)
        self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt)

    @staticmethod
    def _normal_error(src_normal, tgt_normal):
        """Return ``(error, src_vis, tgt_vis)`` for two normal maps.

        Both inputs hold 3 channels on axis -3, e.g. ``(3, H, W)`` or
        ``(B, 3, H, W)``. Each pixel vector is scaled to unit length;
        zero-length pixels such as background are left at zero. The result is
        then remapped with ``(n + 1) / 2`` for visualization.

        ``error`` is the squared difference of the remapped maps, summed over
        channels, averaged over the remaining elements, and multiplied by 4.
        The factor 4 undoes the 1/2 remap scale, so the error equals the mean
        of ``||n_src - n_tgt||^2`` over the unit vectors.
        """
        src_norm = torch.norm(src_normal, dim=-3, keepdim=True)
        tgt_norm = torch.norm(tgt_normal, dim=-3, keepdim=True)
        # Avoid dividing by zero on empty pixels.
        src_norm = src_norm.masked_fill(src_norm == 0.0, 1.0)
        tgt_norm = tgt_norm.masked_fill(tgt_norm == 0.0, 1.0)

        src_vis = (src_normal / src_norm + 1.0) * 0.5
        tgt_vis = (tgt_normal / tgt_norm + 1.0) * 0.5
        error = ((src_vis - tgt_vis) ** 2).sum(dim=-3).mean() * 4.0
        return error, src_vis, tgt_vis

    def calculate_normal_consist(self, normal_path):
        """Render both meshes from cameras 0-3 and compare the images as normals.

        The per-camera renders are tiled into a 1x4 grid. A visualization with
        the prediction grid on top and the ground-truth grid below is saved to
        ``normal_path``.

        Returns:
            A scalar error over the whole grid. If the renderer returns more
            than 4 images, a list with one error per image instead.
        """
        self.render.meshes = self.src_mesh
        src_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')

        # The images are concatenated on dim 0 and tiled by make_grid, so they
        # are presumably (1, 3, H, W) each; the grid is then (3, H, 4W).
        src_grid = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0)
        tgt_grid = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0)
        error, src_vis, tgt_vis = self._normal_error(src_grid, tgt_grid)

        # Stack prediction above ground truth (dim 1 is height for (3, H, W)).
        normal_img = Image.fromarray(
            (torch.cat([src_vis, tgt_vis], dim=1).permute(
                1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8))
        normal_img.save(normal_path)

        if len(src_normal_imgs) > 4:
            # Per-view errors; _normal_error normalizes along the channel axis
            # (-3), not the leading batch axis of each 4-D image.
            return [
                self._normal_error(src, tgt)[0]
                for src, tgt in zip(src_normal_imgs, tgt_normal_imgs)
            ]
        return error

    def export_mesh(self, dir, name):
        """Save the predicted and ground-truth meshes to ``dir``.

        The files are written as ``{name}_src.obj`` and ``{name}_tgt.obj``.
        """
        IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj"))
        IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj"))

    def calculate_chamfer_p2s(self, num_samples=1000):
        """Return ``(chamfer, p2s)`` distances, both multiplied by 100.

        ``num_samples`` points are sampled from each mesh surface. ``p2s`` is
        the mean distance from ground-truth samples to the predicted mesh.
        ``chamfer`` is the average of ``p2s`` and the mean distance from
        predicted samples to the ground-truth mesh.
        """
        # Sampling order (target first) is kept so RNG-dependent results match.
        tgt_points = Pointclouds(
            sample_points_from_meshes(self.tgt_mesh, num_samples))
        src_points = Pointclouds(
            sample_points_from_meshes(self.src_mesh, num_samples))
        p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points) * 100.0
                        + p2s_dist) * 0.5
        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Classification metrics of ``output`` against ``target`` at ``thres``.

        For accuracy, values of ``output`` (and of ``target`` when ``use_sdf``)
        below ``thres`` become 0 and values above become 1. Values exactly
        equal to ``thres`` are left unchanged. For IoU / precision / recall, an
        element is positive when its raw value is ``> thres``.

        Returns:
            ``(accuracy, iou, precision, recall)`` as scalar tensors.
        """
        with torch.no_grad():
            # Build every mask from the raw values: chaining masked_fill on the
            # already-filled tensor turns the zeros into ones when thres < 0.
            pred_pos = output > thres
            pred_bin = output.masked_fill(output < thres, 0.0).masked_fill(
                pred_pos, 1.0)
            gt_pos = target > thres
            if use_sdf:
                target_cmp = target.masked_fill(target < thres, 0.0).masked_fill(
                    gt_pos, 1.0)
            else:
                target_cmp = target
            acc = pred_bin.eq(target_cmp).float().mean()

            # Only the denominators are clamped (to avoid 0/0). Clamping the
            # true-positive count would report 1/denominator instead of 0 when
            # prediction and ground truth do not overlap.
            true_pos = (pred_pos & gt_pos).sum().float()
            union = torch.clamp((pred_pos | gt_pos).sum().float(), min=1.0)
            vol_pred = torch.clamp(pred_pos.sum().float(), min=1.0)
            vol_gt = torch.clamp(gt_pos.sum().float(), min=1.0)
        return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt