PSHuman / lib /dataset /PointFeat.py
fffiloni's picture
Migrated from GitHub
2252f3d verified
from pytorch3d.structures import Meshes, Pointclouds
import torch.nn.functional as F
import torch
from lib.common.render_utils import face_vertices
from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection
from kaolin.ops.mesh import check_sign, face_normals
from kaolin.metrics.trianglemesh import point_to_mesh_distance
from lib.dataset.Evaluator import point_mesh_distance
from lib.dataset.ECON_Evaluator import econ_point_mesh_distance
def distance_matrix(x, y=None, p = 2): #pairwise distance of vectors
y = x if type(y) == type(None) else y
n = x.size(0)
m = y.size(0)
d = x.size(1)
x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)
dist = torch.norm(x - y, dim=-1) if torch.__version__ >= '1.7.0' else torch.pow(x - y, p).sum(2)**(1/p)
return dist
class NN():
def __init__(self, X = None, Y = None, p = 2):
self.p = p
self.train(X, Y)
def train(self, X, Y):
self.train_pts = X
self.train_label = Y
def __call__(self, x):
return self.predict(x)
def predict(self, x):
if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
name = self.__class__.__name__
raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first")
dist=[]
chunk=10000
for i in range(0,x.shape[0],chunk):
dist.append(distance_matrix(x[i:i+chunk], self.train_pts, self.p))
dist = torch.cat(dist, dim=0)
labels = torch.argmin(dist, dim=1)
return self.train_label[labels],labels
class PointFeat:
def __init__(self, verts, faces):
# verts [B, N_vert, 3]
# faces [B, N_face, 3]
# triangles [B, N_face, 3, 3]
self.Bsize = verts.shape[0]
self.mesh = Meshes(verts, faces)
self.device = verts.device
self.faces = faces
# SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth
# 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929
# 2. fill mouth holes with 30 more faces
if verts.shape[1] == 10475:
faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
mouth_faces = (torch.as_tensor(
SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(
self.Bsize, 1, 1).to(self.device))
self.faces = torch.cat([faces, mouth_faces], dim=1).long()
self.verts = verts
self.triangles = face_vertices(self.verts, self.faces)
def get_face_normals(self):
return face_normals(self.verts, self.faces)
def get_nearest_point(self,points):
# points [1, N, 3]
# find nearest point on mesh
#devices = points.device
points=points.squeeze(0)
nn_class=NN(X=self.verts.squeeze(0),Y=self.verts.squeeze(0),p=2)
nearest_points,nearest_points_ind=nn_class.predict(points)
# closest_triangles = torch.gather(
# self.triangles, 1,
# pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
# bary_weights = barycentric_coordinates_of_projection(
# points.view(-1, 3), closest_triangles)
# bary_weights=F.normalize(bary_weights, p=2, dim=1)
# normals = face_normals(self.triangles)
# # make the lenght of the normal is 1
# normals = F.normalize(normals, p=2, dim=2)
# # get the normal of the closest triangle
# closest_normals = torch.gather(
# normals, 1,
# pts_ind[:, :, None].expand(-1, -1, 3)).view(-1, 3)
return nearest_points,nearest_points_ind # on cpu
def query_barycentirc_feats(self,points,feats):
# feats [B,N,C]
residues, pts_ind, _ = point_to_mesh_distance(points, self.triangles)
closest_triangles = torch.gather(
self.triangles, 1,
pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
bary_weights = barycentric_coordinates_of_projection(
points.view(-1, 3), closest_triangles)
feat_arr=feats
feat_dim = feat_arr.shape[-1]
feat_tri = face_vertices(feat_arr, self.faces)
closest_feats = torch.gather( # query点距离最近的face的三个点的feature
feat_tri, 1,
pts_ind[:, :, None,
None].expand(-1, -1, 3,
feat_dim)).view(-1, 3, feat_dim)
pts_feats = ((closest_feats *
bary_weights[:, :, None]).sum(1).unsqueeze(0)) # 用barycentric weight加权求和
return pts_feats.view(self.Bsize,-1,feat_dim)
def query(self, points, feats={}):
# points [B, N, 3]
# feats {'feat_name': [B, N, C]}
del_keys = ["smpl_verts", "smpl_faces", "smpl_joint","smpl_sample_id"]
residues, pts_ind, _ = point_to_mesh_distance(points, self.triangles)
closest_triangles = torch.gather(
self.triangles, 1,
pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
bary_weights = barycentric_coordinates_of_projection(
points.view(-1, 3), closest_triangles)
out_dict = {}
for feat_key in feats.keys():
if feat_key in del_keys:
continue
elif feats[feat_key] is not None:
feat_arr = feats[feat_key]
feat_dim = feat_arr.shape[-1]
feat_tri = face_vertices(feat_arr, self.faces)
closest_feats = torch.gather( # query点距离最近的face的三个点的feature
feat_tri, 1,
pts_ind[:, :, None,
None].expand(-1, -1, 3,
feat_dim)).view(-1, 3, feat_dim)
pts_feats = ((closest_feats *
bary_weights[:, :, None]).sum(1).unsqueeze(0)) # 用barycentric weight加权求和
out_dict[feat_key.split("_")[1]] = pts_feats
else:
out_dict[feat_key.split("_")[1]] = None
if "sdf" in out_dict.keys():
pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3))
pts_signs = 2.0 * (
check_sign(self.verts, self.faces[0], points).float() - 0.5)
pts_sdf = (pts_dist * pts_signs).unsqueeze(-1)
out_dict["sdf"] = pts_sdf
if "vis" in out_dict.keys():
out_dict["vis"] = out_dict["vis"].ge(1e-1).float()
if "norm" in out_dict.keys():
pts_norm = out_dict["norm"] * torch.tensor([-1.0, 1.0, -1.0]).to(
self.device)
out_dict["norm"] = F.normalize(pts_norm, dim=2)
if "cmap" in out_dict.keys():
out_dict["cmap"] = out_dict["cmap"].clamp_(min=0.0, max=1.0)
for out_key in out_dict.keys():
out_dict[out_key] = out_dict[out_key].view(
self.Bsize, -1, out_dict[out_key].shape[-1])
return out_dict
class ECON_PointFeat:
def __init__(self, verts, faces):
# verts [B, N_vert, 3]
# faces [B, N_face, 3]
# triangles [B, N_face, 3, 3]
self.Bsize = verts.shape[0]
self.device = verts.device
self.faces = faces
# SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth
# 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929
# 2. fill mouth holes with 30 more faces
if verts.shape[1] == 10475:
faces = faces[:, ~SMPLX().smplx_eyeball_fid_mask]
mouth_faces = (
torch.as_tensor(SMPLX().smplx_mouth_fid).unsqueeze(0).repeat(self.Bsize, 1,
1).to(self.device)
)
self.faces = torch.cat([faces, mouth_faces], dim=1).long()
self.verts = verts.float()
self.triangles = face_vertices(self.verts, self.faces)
self.mesh = Meshes(self.verts, self.faces).to(self.device)
def query(self, points):
points = points.float()
residues, pts_ind = econ_point_mesh_distance(self.mesh, Pointclouds(points), weighted=False) # 这个和ECON的不太一样
closest_triangles = torch.gather(
self.triangles, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
).view(-1, 3, 3)
bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)
feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
closest_normals = torch.gather(
feat_normals, 1, pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
).view(-1, 3, 3)
shoot_verts = ((closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0))
pts2shoot_normals = points - shoot_verts
pts2shoot_normals = pts2shoot_normals / torch.norm(pts2shoot_normals, dim=-1, keepdim=True)
shoot_normals = ((closest_normals * bary_weights[:, :, None]).sum(1).unsqueeze(0))
shoot_normals = shoot_normals / torch.norm(shoot_normals, dim=-1, keepdim=True)
angles = (pts2shoot_normals * shoot_normals).sum(dim=-1).abs()
return (torch.sqrt(residues).unsqueeze(0), angles)