File size: 2,460 Bytes
da48dbe
487ee6d
 
da48dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
 
 
 
da48dbe
 
 
 
 
 
 
 
 
 
 
fb140f6
 
 
da48dbe
 
 
fb140f6
 
 
da48dbe
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from pytorch3d.structures import Meshes, Pointclouds

from lib.common.render_utils import face_vertices
from lib.dataset.Evaluator import point_mesh_distance
from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection


class PointFeat:
    """Query per-point geometric features against a triangle body mesh.

    For each query point, ``query`` returns the square root of the residue reported
    by ``point_mesh_distance``. It also returns the absolute cosine between two
    directions:

    - the direction from the point's projection on its closest triangle to the point;
    - the barycentrically interpolated vertex normal at that projection.
    """

    def __init__(self, verts, faces):
        """Build the mesh used by ``query``.

        Args:
            verts: vertex tensor, [B, N_vert, 3].
            faces: face index tensor, [B, N_face, 3].

        ``self.triangles`` is built with ``face_vertices`` and is [B, N_face, 3, 3].
        """
        self.Bsize = verts.shape[0]
        self.device = verts.device
        self.faces = faces

        # SMPL has watertight mesh, but SMPL-X has two eyeballs and open mouth
        # 1. remove eye_ball faces from SMPL-X: 9928-9383, 10474-9929
        # 2. fill mouth holes with 30 more faces
        if verts.shape[1] == 10475:
            # Construct the SMPLX helper once and reuse it for both lookups.
            smplx = SMPLX()
            faces = faces[:, ~smplx.smplx_eyeball_fid_mask]
            mouth_faces = (
                torch.as_tensor(smplx.smplx_mouth_fid)
                .unsqueeze(0)
                .repeat(self.Bsize, 1, 1)
                .to(self.device)
            )
            self.faces = torch.cat([faces, mouth_faces], dim=1).long()

        self.verts = verts.float()
        self.triangles = face_vertices(self.verts, self.faces)
        self.mesh = Meshes(self.verts, self.faces).to(self.device)

    @staticmethod
    def _unit(vectors, eps=1e-12):
        """Normalize ``vectors`` to unit length along the last dimension.

        With a plain ``v / ||v||``, a zero-length vector gives 0/0 = NaN. An example
        is a query point lying exactly on the surface. Clamping the norm to ``eps``
        maps such vectors to the zero vector instead. For any vector whose norm is at
        least ``eps``, the result is identical to plain division.
        """
        return vectors / torch.norm(vectors, dim=-1, keepdim=True).clamp_min(eps)

    def query(self, points):
        """Compute distance and normal-alignment features for ``points``.

        Args:
            points: query points, cast to float. They are flattened with
                ``view(-1, 3)`` for the barycentric step, so presumably
                [1, N_pts, 3] — TODO confirm.

        Returns:
            Tuple ``(dist, angles)``:

            - ``dist``: ``sqrt(residues)`` with a leading dim of size 1. This is
              presumably the Euclidean point-to-mesh distance, if
              ``point_mesh_distance`` returns squared distances — verify.
            - ``angles``: absolute cosine between two directions: from the point's
              projection on its closest triangle to the point, and the interpolated
              vertex normal at that projection. It is 0 (not NaN) when either
              direction has zero length.
        """
        points = points.float()
        residues, pts_ind = point_mesh_distance(self.mesh, Pointclouds(points), weighted=False)

        # NOTE(review): the gather index has size 1 in the batch dimension, so
        # triangles and normals are always read from batch 0. This appears to assume
        # Bsize == 1 — confirm before using batched meshes.
        gather_idx = pts_ind[None, :, None, None].expand(-1, -1, 3, 3)

        closest_triangles = torch.gather(self.triangles, 1, gather_idx).view(-1, 3, 3)
        bary_weights = barycentric_coordinates_of_projection(points.view(-1, 3), closest_triangles)

        feat_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
        closest_normals = torch.gather(feat_normals, 1, gather_idx).view(-1, 3, 3)

        # Barycentric interpolation of the closest triangle's vertices and vertex normals.
        shoot_verts = (closest_triangles * bary_weights[:, :, None]).sum(1).unsqueeze(0)
        shoot_normals = (closest_normals * bary_weights[:, :, None]).sum(1).unsqueeze(0)

        pts2shoot_normals = self._unit(points - shoot_verts)
        shoot_normals = self._unit(shoot_normals)
        angles = (pts2shoot_normals * shoot_normals).sum(dim=-1).abs()

        return (torch.sqrt(residues).unsqueeze(0), angles)