Spaces:
Running
on
L40S
Running
on
L40S
from pytorch3d.structures import Meshes, Pointclouds | |
import torch.nn.functional as F | |
import torch | |
from lib.common.render_utils import face_vertices | |
from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection | |
from kaolin.ops.mesh import check_sign, face_normals | |
from kaolin.metrics.trianglemesh import point_to_mesh_distance | |
from lib.dataset.Evaluator import point_mesh_distance | |
from lib.dataset.ECON_Evaluator import econ_point_mesh_distance | |
def distance_matrix(x, y=None, p=2):
    """Return the pairwise p-norm distances between the rows of two point sets.

    Args:
        x: tensor of shape [n, d].
        y: tensor of shape [m, d]; when omitted, ``x`` is compared with itself.
        p: order of the norm (2 gives the Euclidean distance).

    Returns:
        Tensor of shape [n, m] whose entry (i, j) is ``||x[i] - y[j]||_p``.
    """
    if y is None:
        y = x
    # Broadcasting [n, 1, d] against [1, m, d] yields every pairwise difference.
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    # The norm order is passed explicitly. The previous code picked its formula
    # by comparing ``torch.__version__`` as a string (so "1.10" sorted before
    # "1.7"), one formula ignored ``p`` and the other lacked an absolute value.
    return torch.norm(diff, p=p, dim=-1)
class NN():
    """Brute-force 1-nearest-neighbour lookup over a stored set of points."""

    def __init__(self, X=None, Y=None, p=2):
        """Store the reference points ``X``, their labels ``Y`` and the norm order ``p``."""
        self.p = p
        self.train(X, Y)

    def train(self, X, Y):
        """Remember the reference points and the label attached to each of them."""
        self.train_pts = X
        self.train_label = Y

    def __call__(self, x):
        """Shorthand for :meth:`predict`."""
        return self.predict(x)

    def predict(self, x, chunk=10000):
        """Return the label of the closest reference point for every row of ``x``.

        Args:
            x: query tensor whose rows are compared against the stored points.
            chunk: number of query rows processed per distance computation.

        Returns:
            Tuple ``(labels, indices)``: ``indices[i]`` is the position of the
            reference point nearest to ``x[i]`` and ``labels`` is
            ``train_label[indices]``.

        Raises:
            RuntimeError: if no reference points or labels have been stored.
        """
        if self.train_pts is None or self.train_label is None:
            name = self.__class__.__name__
            raise RuntimeError(f"{name} wasn't trained. Need to execute {name}.train() first")

        # Reduce each chunk to its argmin right away so that only one
        # [chunk, n_train] distance matrix is alive at a time.
        nearest = []
        for start in range(0, x.shape[0], chunk):
            dist = distance_matrix(x[start:start + chunk], self.train_pts, self.p)
            nearest.append(torch.argmin(dist, dim=1))

        if nearest:
            labels = torch.cat(nearest, dim=0)
        else:
            # Empty query: torch.cat rejects an empty list, so build the
            # empty index tensor directly.
            labels = torch.zeros(0, dtype=torch.long, device=x.device)
        return self.train_label[labels], labels
class PointFeat:
    """Look up per-point quantities from the closest face of a batch of triangle meshes."""

    def __init__(self, verts, faces):
        """Store the mesh and precompute the corner positions of every face.

        Args:
            verts: [B, N_vert, 3] vertex positions.
            faces: [B, N_face, 3] vertex indices of each triangle.
        """
        self.Bsize = verts.shape[0]
        # NOTE: this Meshes object is built from the faces exactly as passed in,
        # i.e. before the SMPL-X patch applied below.
        self.mesh = Meshes(verts, faces)
        self.device = verts.device
        self.faces = faces

        # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth
        # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929
        # 2. fill mouth holes with 30 more faces
        if verts.shape[1] == 10475:
            smplx = SMPLX()  # one instance serves both lookups
            faces = faces[:, ~smplx.smplx_eyeball_fid_mask]
            mouth_faces = (
                torch.as_tensor(smplx.smplx_mouth_fid)
                .unsqueeze(0)
                .repeat(self.Bsize, 1, 1)
                .to(self.device)
            )
            self.faces = torch.cat([faces, mouth_faces], dim=1).long()

        self.verts = verts
        # triangles [B, N_face, 3, 3]
        self.triangles = face_vertices(self.verts, self.faces)

    def get_face_normals(self):
        """Return one normal per face, computed from ``self.triangles``."""
        # kaolin's face_normals takes the per-face corner positions
        # ([B, N_face, 3, 3]) plus an optional ``unit`` flag; passing
        # (verts, faces) would bind the face tensor to that flag.
        return face_normals(self.triangles)

    def get_nearest_point(self, points):
        """Find the mesh vertex closest to each query point.

        Args:
            points: [1, N, 3] query positions; the leading dimension is
                squeezed away, as is the one of ``self.verts``.

        Returns:
            Tuple ``(nearest_points, nearest_points_ind)``: the positions of the
            closest vertices and their indices into the vertex list.
        """
        points = points.squeeze(0)
        verts = self.verts.squeeze(0)
        # The vertices act as both reference points and labels, so the
        # lookup returns the vertex positions themselves.
        nn_class = NN(X=verts, Y=verts, p=2)
        nearest_points, nearest_points_ind = nn_class.predict(points)
        return nearest_points, nearest_points_ind

    def _closest_faces(self, points):
        """Return (residues, face indices, barycentric weights) of the nearest face per point."""
        residues, pts_ind, _ = point_to_mesh_distance(points, self.triangles)
        closest_triangles = torch.gather(
            self.triangles, 1,
            pts_ind[:, :, None, None].expand(-1, -1, 3, 3)).view(-1, 3, 3)
        bary_weights = barycentric_coordinates_of_projection(
            points.view(-1, 3), closest_triangles)
        return residues, pts_ind, bary_weights

    def _interpolate(self, feat_arr, pts_ind, bary_weights):
        """Blend a per-vertex feature over each point's nearest face; result has a leading dim of 1."""
        feat_dim = feat_arr.shape[-1]
        feat_tri = face_vertices(feat_arr, self.faces)
        # Features at the three corners of the face closest to each query point.
        closest_feats = torch.gather(
            feat_tri, 1,
            pts_ind[:, :, None, None].expand(-1, -1, 3, feat_dim)
        ).view(-1, 3, feat_dim)
        # Weighted sum of the corner features using the barycentric weights.
        return (closest_feats * bary_weights[:, :, None]).sum(1).unsqueeze(0)

    def query_barycentirc_feats(self, points, feats):
        """Interpolate one per-vertex feature at the given points.

        Args:
            points: [B, N, 3] query positions.
            feats: [B, N_vert, C] per-vertex feature.

        Returns:
            Tensor viewed as [B, -1, C] holding the interpolated feature.
        """
        _, pts_ind, bary_weights = self._closest_faces(points)
        pts_feats = self._interpolate(feats, pts_ind, bary_weights)
        return pts_feats.view(self.Bsize, -1, feats.shape[-1])

    def query(self, points, feats=None):
        """Interpolate a dictionary of per-vertex features at the given points.

        Args:
            points: [B, N, 3] query positions.
            feats: ``{'feat_name': [B, N_vert, C]}``. Keys listed in ``del_keys``
                are skipped; every other key is stored under the second
                ``_``-separated token of its name. A ``None`` value is passed
                through as ``None``.

        Returns:
            Dict of tensors viewed as [B, -1, C] (or ``None`` for features that
            were ``None``). Special keys: "sdf" is replaced by a signed distance
            derived from the mesh, "vis" is thresholded at 0.1, "norm" has its
            x and z components negated and is normalised, "cmap" is clamped to
            [0, 1].
        """
        feats = {} if feats is None else feats
        del_keys = ["smpl_verts", "smpl_faces", "smpl_joint", "smpl_sample_id"]

        residues, pts_ind, bary_weights = self._closest_faces(points)

        out_dict = {}
        for feat_key, feat_arr in feats.items():
            if feat_key in del_keys:
                continue
            out_key = feat_key.split("_")[1]
            if feat_arr is not None:
                out_dict[out_key] = self._interpolate(feat_arr, pts_ind, bary_weights)
            else:
                out_dict[out_key] = None

        if "sdf" in out_dict:
            # residues are square-rooted here, so they are presumably squared
            # distances; the sqrt(3) divisor is kept from the original code.
            # A Python scalar avoids taking the sqrt of an integer tensor.
            pts_dist = torch.sqrt(residues) / (3.0 ** 0.5)
            # +1 where check_sign reports True, -1 elsewhere.
            pts_signs = 2.0 * (
                check_sign(self.verts, self.faces[0], points).float() - 0.5)
            out_dict["sdf"] = (pts_dist * pts_signs).unsqueeze(-1)

        # The remaining post-processing steps skip features passed in as None.
        if out_dict.get("vis") is not None:
            out_dict["vis"] = out_dict["vis"].ge(1e-1).float()

        if out_dict.get("norm") is not None:
            pts_norm = out_dict["norm"] * torch.tensor([-1.0, 1.0, -1.0]).to(
                self.device)
            out_dict["norm"] = F.normalize(pts_norm, dim=2)

        if out_dict.get("cmap") is not None:
            out_dict["cmap"] = out_dict["cmap"].clamp_(min=0.0, max=1.0)

        for out_key, value in out_dict.items():
            if value is not None:
                out_dict[out_key] = value.view(self.Bsize, -1, value.shape[-1])

        return out_dict
class ECON_PointFeat:
    """Measure, for query points, the distance to a mesh and the alignment with its normals."""

    def __init__(self, verts, faces):
        """Store the mesh, its per-face corner positions and a pytorch3d ``Meshes`` object.

        Args:
            verts: [B, N_vert, 3] vertex positions (converted to float).
            faces: [B, N_face, 3] vertex indices of each triangle.
        """
        self.Bsize = verts.shape[0]
        self.device = verts.device
        self.faces = faces

        # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth
        # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929
        # 2. fill mouth holes with 30 more faces
        if verts.shape[1] == 10475:
            smplx = SMPLX()  # one instance serves both lookups
            faces = faces[:, ~smplx.smplx_eyeball_fid_mask]
            mouth_faces = (
                torch.as_tensor(smplx.smplx_mouth_fid)
                .unsqueeze(0)
                .repeat(self.Bsize, 1, 1)
                .to(self.device)
            )
            self.faces = torch.cat([faces, mouth_faces], dim=1).long()

        self.verts = verts.float()
        # triangles [B, N_face, 3, 3]
        self.triangles = face_vertices(self.verts, self.faces)
        self.mesh = Meshes(self.verts, self.faces).to(self.device)

    def query(self, points):
        """Return the distance to the mesh and a normal-alignment score per point.

        Args:
            points: query positions; converted to float and flattened to
                [-1, 3] for the barycentric projection.

        Returns:
            Tuple ``(distances, angles)``. ``distances`` is the square root of
            the residues from ``econ_point_mesh_distance`` with a leading
            dimension of 1. ``angles`` is the absolute dot product between the
            unit direction from the projected surface point to the query point
            and the unit interpolated vertex normal at that surface point.
            A point lying exactly on its projection has a zero-length
            direction, so its ``angles`` entry is NaN.
        """
        points = points.float()
        # Original author's note: this differs somewhat from ECON's version.
        residues, pts_ind = econ_point_mesh_distance(
            self.mesh, Pointclouds(points), weighted=False
        )

        # The index carries a leading dimension of 1, so the gathers below
        # read a single mesh -- presumably B == 1 is expected; TODO confirm.
        tri_index = pts_ind[None, :, None, None].expand(-1, -1, 3, 3)

        closest_triangles = torch.gather(self.triangles, 1, tri_index).view(-1, 3, 3)
        bary_weights = barycentric_coordinates_of_projection(
            points.view(-1, 3), closest_triangles
        )
        weights = bary_weights[:, :, None]

        # Vertex normals arranged per face corner, then picked for the closest face.
        feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
        closest_normals = torch.gather(feat_normals, 1, tri_index).view(-1, 3, 3)

        # Barycentric blend of the corner positions: the projected surface point.
        shoot_verts = (closest_triangles * weights).sum(1).unsqueeze(0)
        pts2shoot_normals = points - shoot_verts
        pts2shoot_normals = pts2shoot_normals / torch.norm(
            pts2shoot_normals, dim=-1, keepdim=True
        )

        # Barycentric blend of the corner normals at the same surface point.
        shoot_normals = (closest_normals * weights).sum(1).unsqueeze(0)
        shoot_normals = shoot_normals / torch.norm(shoot_normals, dim=-1, keepdim=True)

        angles = (pts2shoot_normals * shoot_normals).sum(dim=-1).abs()

        return (torch.sqrt(residues).unsqueeze(0), angles)