|
|
|
|
|
from packaging import version |
|
import torch |
|
import scipy |
|
import os |
|
import numpy as np |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from lib.common.config import cfg |
|
from lib.pymaf.utils.geometry import projection |
|
from lib.pymaf.core.path_config import MESH_DOWNSAMPLEING |
|
|
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class MAF_Extractor(nn.Module):
    """Mesh-aligned Feature Extractor.

    As discussed in the paper, mesh-aligned features are extracted based on
    the 2D projection of the mesh vertices. The features sampled from the
    spatial feature maps then go through an MLP (a stack of 1x1 ``Conv1d``
    layers) for dimension reduction.
    """

    def __init__(self, device=torch.device('cuda')):
        """Build the point-wise MLP and the mesh down-sampling buffer.

        Args:
            device: stored on ``self.device``; it is not otherwise used by
                this class.
        """
        super().__init__()

        self.device = device
        self.filters = []
        self.num_views = 1
        filter_channels = cfg.MODEL.PyMAF.MLP_DIM
        self.last_op = nn.ReLU(True)

        for idx in range(len(filter_channels) - 1):
            in_channels = filter_channels[idx]
            if idx != 0:
                # Every layer after the first also receives the raw input
                # features (skip connection, see ``reduce_dim``).
                in_channels += filter_channels[0]
            conv = nn.Conv1d(in_channels, filter_channels[idx + 1], 1)
            self.filters.append(conv)
            # ``self.filters`` is a plain list, so each layer is registered
            # explicitly to make its parameters visible to the module.
            self.add_module("conv%d" % idx, conv)

        # Optional cached defaults used by ``sampling`` / ``forward`` when the
        # corresponding argument is omitted.
        self.im_feat = None
        self.cam = None

        # NOTE(security): ``allow_pickle=True`` unpickles objects stored in
        # the file; only load this from a trusted location.
        mesh_graph = np.load(MESH_DOWNSAMPLEING,
                             allow_pickle=True,
                             encoding='latin1')
        try:
            # Only the 'D' entry is needed here.
            D = mesh_graph['D']
        finally:
            # ``np.load`` keeps ``.npz`` archives open until they are closed;
            # release the handle once the data has been read.
            if hasattr(mesh_graph, 'close'):
                mesh_graph.close()

        # Product of the first two matrices of 'D', stored as a dense
        # (non-trainable) buffer so it follows the module across devices.
        Dmap = torch.matmul(self._sparse_to_dense(D[1]),
                            self._sparse_to_dense(D[0]))
        self.register_buffer('Dmap', Dmap)

    @staticmethod
    def _sparse_to_dense(matrix):
        """Convert a (SciPy-compatible) sparse matrix to a dense float32 tensor."""
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse`` is loaded on every SciPy version.
        from scipy.sparse import coo_matrix

        coo = coo_matrix(matrix)
        indices = torch.as_tensor(np.vstack((coo.row, coo.col)),
                                  dtype=torch.long)
        values = torch.as_tensor(coo.data, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values,
                                       tuple(coo.shape)).to_dense()

    def reduce_dim(self, feature):
        """Reduce the feature dimension with the point-wise MLP.

        :param feature: [B, C_s, N] tensor of point-wise features before
            dimension reduction
        :return: [B, C_p x N] concatenation of point-wise features after
            dimension reduction
        """
        y = feature
        skip = feature
        num_layers = len(self.filters)
        for i in range(num_layers):
            # Fetch the layer through the registered submodules rather than
            # ``self.filters`` (presumably so wrapped/replicated copies of the
            # module use their own layers -- the list is not tracked by torch).
            layer = self._modules['conv' + str(i)]
            # Layers after the first see the raw input concatenated on the
            # channel axis.
            y = layer(y if i == 0 else torch.cat([y, skip], 1))
            if i != num_layers - 1:
                y = F.leaky_relu(y)
            if self.num_views > 1 and i == num_layers // 2:
                # Multi-view fusion: average both the current activations and
                # the skip features over the view axis.
                y = y.view(-1, self.num_views, y.shape[1],
                           y.shape[2]).mean(dim=1)
                skip = feature.view(-1, self.num_views, feature.shape[1],
                                    feature.shape[2]).mean(dim=1)

        y = self.last_op(y)

        # Flatten channels and points into a single feature vector per sample.
        return y.view(y.shape[0], -1)

    def sampling(self, points, im_feat=None, z_feat=None):
        """Sample point-wise features at 2D locations and reduce their dimension.

        The dimension of the point-wise features is reduced from C_s to C_p
        by the MLP. Image features must either be passed in or be cached on
        ``self.im_feat`` before this call.

        :param points: [B, N, 2] image coordinates of points
        :param im_feat: [B, C_s, H_s, W_s] spatial feature maps; falls back to
            ``self.im_feat`` when None
        :param z_feat: unused; kept for interface compatibility
        :return: [B, C_p x N] concatenation of point-wise features after
            dimension reduction
        :raises ValueError: if no feature map is given and none is cached
        """
        if im_feat is None:
            im_feat = self.im_feat
        if im_feat is None:
            raise ValueError(
                'im_feat must be provided or set on the extractor '
                'before sampling')

        # [B, N, 2] -> [B, N, 1, 2] so grid_sample yields [B, C_s, N, 1].
        grid = points.unsqueeze(2)
        if version.parse(torch.__version__) >= version.parse('1.3.0'):
            # ``align_corners`` only exists from torch 1.3.0 on; pass it
            # explicitly there.
            sampled = F.grid_sample(im_feat, grid, align_corners=True)
        else:
            sampled = F.grid_sample(im_feat, grid)
        point_feat = sampled[..., 0]

        return self.reduce_dim(point_feat)

    def forward(self, p, s_feat=None, cam=None, **kwargs):
        """Return mesh-aligned features for the 3D mesh points.

        Args:
            p (tensor): [B, N_m, 3] mesh vertices
            s_feat (tensor): [B, C_s, H_s, W_s] spatial feature maps; when
                None, ``sampling`` falls back to ``self.im_feat``
            cam (tensor): [B, 3] camera; falls back to ``self.cam`` when None
            **kwargs: accepted and ignored
        Return:
            mesh_align_feat (tensor): [B, C_p x N_m] mesh-aligned features
        """
        if cam is None:
            cam = self.cam
        # Project the vertices to 2D image coordinates (depth is dropped).
        p_proj_2d = projection(p, cam, retain_z=False)
        return self.sampling(p_proj_2d, s_feat)
|
|