|
|
|
|
|
from typing import Optional |
|
from dataclasses import dataclass |
|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import pickle |
|
from lib.smplx import SMPL as _SMPL |
|
from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer |
|
from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes |
|
from lib.smplx.body_models import SMPLXOutput |
|
import json |
|
|
|
from lib.pymafx.core import path_config, constants |
|
|
|
# Module-level aliases of the model paths configured in ``path_config``.
SMPL_MEAN_PARAMS, SMPL_MODEL_DIR = path_config.SMPL_MEAN_PARAMS, path_config.SMPL_MODEL_DIR
|
|
|
|
|
@dataclass
class ModelOutput(SMPLXOutput):
    """Output container shared by the body-model wrappers in this module.

    Extends ``SMPLXOutput`` with optional extra fields. Each wrapper fills only the
    fields it computes; every other field stays ``None``.
    """

    # 24 SMPL joints: first 24 base joints in SMPL.forward, first 24 of the
    # constants.SMPLX2SMPL_J45 selection in SMPLX_ALL.forward.
    smpl_joints: Optional[torch.Tensor] = None

    # constants.J24_TO_J19 subset of the last 24 ``joint_map``-ordered joints.
    joints_J19: Optional[torch.Tensor] = None

    # Full SMPL-X mesh vertices (SMPLX_ALL.forward).
    smplx_vertices: Optional[torch.Tensor] = None

    # FLAME mesh vertices (FLAME.forward).
    flame_vertices: Optional[torch.Tensor] = None

    # Left-hand vertices: subset of the SMPL vertices (SMPLX_ALL.forward).
    lhand_vertices: Optional[torch.Tensor] = None

    # Right-hand vertices: SMPL-vertex subset (SMPLX_ALL) or MANO vertices (MANO).
    rhand_vertices: Optional[torch.Tensor] = None

    # Left-hand joints selected from the SMPL-X joints (SMPLX_ALL.forward).
    lhand_joints: Optional[torch.Tensor] = None

    # Right-hand joints: SMPL-X joint subset (SMPLX_ALL) or MANO joints + tips (MANO).
    rhand_joints: Optional[torch.Tensor] = None

    # Face joints: trailing 68 / 51 SMPL-X joints (SMPLX_ALL) or FLAME joints[:, 5:].
    face_joints: Optional[torch.Tensor] = None

    # Left-foot joints selected from the SMPL-X joints (SMPLX_ALL.forward).
    lfoot_joints: Optional[torch.Tensor] = None

    # Right-foot joints selected from the SMPL-X joints (SMPLX_ALL.forward).
    rfoot_joints: Optional[torch.Tensor] = None
|
|
|
|
|
class SMPL(_SMPL):
    """Extension of the SMPL model from ``lib.smplx`` that supports more joints.

    Adds joints regressed with ``J_regressor_extra``, a ``joint_map`` re-indexing of
    the combined joint set, and :meth:`get_global_rotation` (forward kinematics of
    the template skeleton).
    """

    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        """Construct the base SMPL model and register the extra joint data.

        The ``create_*`` flags (all False by default here) and any other positional /
        keyword arguments are passed unchanged to the base SMPL constructor.
        """
        super().__init__(
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            *args,
            **kwargs
        )
        # Indices into [base joints | extra joints] (see forward), looked up by name.
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        # Additional joint regressor applied to the posed vertices in forward.
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        # NOTE: plain attribute, not a registered buffer, so .to(device) does not move it.
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # Joints regressed from the template mesh (no shape blend shapes applied);
        # used as the rest skeleton by get_global_rotation.
        tpose_joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0))
        self.register_buffer('tpose_joints', tpose_joints)

    def forward(self, *args, **kwargs):
        """Run the base SMPL model and return a :class:`ModelOutput` with extra joint sets.

        All arguments are forwarded to the base ``forward`` with ``get_skin=True``.

        Returns
        -------
        ModelOutput
            ``joints``: base joints concatenated with the ``J_regressor_extra`` joints
            and re-indexed by ``joint_map``; ``joints_J19``: ``constants.J24_TO_J19``
            subset of the last 24 mapped joints; ``smpl_joints``: first 24 base joints;
            plus ``vertices``, ``global_orient``, ``body_pose``, ``betas`` and
            ``full_pose`` copied from the base output.
        """
        # Presumably required so the base output contains posed (skinned) vertices.
        kwargs['get_skin'] = True
        smpl_output = super().forward(*args, **kwargs)
        # Extra joints are regressed from the posed vertices.
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)

        vertices = smpl_output.vertices
        joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
        smpl_joints = smpl_output.joints[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = ModelOutput(
            vertices=vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose
        )
        return output

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''Compute global joint rotations and posed joint positions of the template skeleton.

        Runs forward kinematics over the kinematic tree ``self.parents``, starting
        from ``tpose_joints`` (regressed from ``v_template`` in ``__init__``, so shape
        parameters are not applied). No skinning is performed.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Root rotation in rotation-matrix format. Identity if None.
        body_pose: torch.tensor, optional, shape BxJx3x3 with J = NUM_BODY_JOINTS
            Body joint rotations in rotation-matrix format. Identity if None.
        **kwargs
            Ignored.

        The batch size B is the largest first dimension among the given poses
        (1 if neither is given).

        Returns
        -------
        global_rotmat: torch.Tensor, shape (B, J, 3, 3)
            Accumulated (global) rotation of every joint, J = len(self.parents).
        posed_joints: torch.Tensor, shape (B, J, 3)
            Joint positions after applying the pose to the template skeleton.
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        model_vars = [global_orient, body_pose]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        # Missing poses default to identity rotations.
        if global_orient is None:
            global_orient = torch.eye(3, device=device,
                                      dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                            -1).contiguous()
        if body_pose is None:
            body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                batch_size, self.NUM_BODY_JOINTS, -1, -1
            ).contiguous()

        # (B, 1 + NUM_BODY_JOINTS, 3, 3): root first, then body joints.
        full_pose = torch.cat(
            [global_orient.reshape(-1, 1, 3, 3),
             body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3)],
            dim=1
        )

        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # (B, J, 3, 1) rest joints broadcast over the batch.
        joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)

        # Offset of every joint relative to its parent (root keeps its absolute position).
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        # Local 4x4 transforms [R | t] per joint.
        transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
                                       rel_joints.reshape(-1, 3,
                                                          1)).reshape(-1, joints.shape[1], 4, 4)

        # Accumulate along the tree; assumes parents[i] < i so the parent's global
        # transform has already been computed.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # Rotation block of each global transform.
        global_rotmat = transforms[:, :, :3, :3]

        # Translation column of each global transform = posed joint position.
        posed_joints = transforms[:, :, :3, 3]

        return global_rotmat, posed_joints
|
|
|
|
|
class SMPLX(SMPLXLayer):
    """Extension of ``SMPLXLayer`` adding :meth:`get_global_rotation` (template forward kinematics)."""

    def __init__(self, *args, **kwargs):
        """Pass all arguments unchanged to ``SMPLXLayer``."""
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        '''Compute global joint rotations and posed joint positions of the template skeleton.

        Runs forward kinematics over ``self.parents``. The rest joints are regressed
        with ``J_regressor`` from ``v_template``, so shape and expression parameters
        are not applied. No skinning is performed.

        Parameters
        ----------
        global_orient: torch.tensor, optional, shape Bx3x3
            Root rotation in rotation-matrix format. Identity if None.
        body_pose: torch.tensor, optional, shape BxJx3x3 with J = NUM_BODY_JOINTS
            Body joint rotations in rotation-matrix format. Identity if None.
        left_hand_pose: torch.tensor, optional, shape BxHx3x3 with H = NUM_HAND_JOINTS
            Left-hand joint rotations. Identity (15 joints) if None.
        right_hand_pose: torch.tensor, optional, shape BxHx3x3 with H = NUM_HAND_JOINTS
            Right-hand joint rotations. Identity (15 joints) if None.
        jaw_pose, leye_pose, reye_pose: torch.tensor, optional, shape Bx3x3
            Jaw / left-eye / right-eye rotations. Identity if None.
        **kwargs
            Ignored.

        The batch size B is the largest first dimension among the given
        ``global_orient``, ``body_pose``, hand poses and ``jaw_pose`` (1 if none is
        given). The eye poses are not considered.

        Returns
        -------
        global_rotmat: torch.Tensor, shape (B, J, 3, 3)
            Accumulated (global) rotation of every joint, J = len(self.parents).
            The joint order follows the pose concatenation: root, body, jaw,
            left eye, right eye, left hand, right hand.
        posed_joints: torch.Tensor, shape (B, J, 3)
            Joint positions after applying the pose to the template skeleton.
        '''
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # Eye poses are intentionally (or at least effectively) excluded from batch inference.
        model_vars = [global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose]
        batch_size = 1
        for var in model_vars:
            if var is None:
                continue
            batch_size = max(batch_size, len(var))

        # Missing poses default to identity rotations.
        if global_orient is None:
            global_orient = torch.eye(3, device=device,
                                      dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                            -1).contiguous()
        if body_pose is None:
            body_pose = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3).expand(
                batch_size, self.NUM_BODY_JOINTS, -1, -1
            ).contiguous()
        if left_hand_pose is None:
            left_hand_pose = torch.eye(3, device=device,
                                       dtype=dtype).view(1, 1, 3, 3).expand(batch_size, 15, -1,
                                                                             -1).contiguous()
        if right_hand_pose is None:
            right_hand_pose = torch.eye(3, device=device,
                                        dtype=dtype).view(1, 1, 3,
                                                          3).expand(batch_size, 15, -1,
                                                                    -1).contiguous()
        if jaw_pose is None:
            jaw_pose = torch.eye(3, device=device,
                                 dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                       -1).contiguous()
        if leye_pose is None:
            leye_pose = torch.eye(3, device=device,
                                  dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                        -1).contiguous()
        if reye_pose is None:
            reye_pose = torch.eye(3, device=device,
                                  dtype=dtype).view(1, 1, 3, 3).expand(batch_size, -1, -1,
                                                                        -1).contiguous()

        # Concatenation order must match the joint order of J_regressor / parents.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
                jaw_pose.reshape(-1, 1, 3, 3),
                leye_pose.reshape(-1, 1, 3, 3),
                reye_pose.reshape(-1, 1, 3, 3),
                left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
                right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3)
            ],
            dim=1
        )

        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Rest joints regressed from the (batch-expanded) template mesh.
        joints = vertices2joints(
            self.J_regressor,
            self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        )

        joints = torch.unsqueeze(joints, dim=-1)

        # Offset of every joint relative to its parent (root keeps its absolute position).
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        # Local 4x4 transforms [R | t] per joint.
        transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3),
                                       rel_joints.reshape(-1, 3,
                                                          1)).reshape(-1, joints.shape[1], 4, 4)

        # Accumulate along the tree; assumes parents[i] < i so the parent's global
        # transform has already been computed.
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, self.parents.shape[0]):
            curr_res = torch.matmul(transform_chain[self.parents[i]], transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # Rotation block of each global transform.
        global_rotmat = transforms[:, :, :3, :3]

        # Translation column of each global transform = posed joint position.
        posed_joints = transforms[:, :, :3, 3]

        return global_rotmat, posed_joints
|
|
|
|
|
class SMPLX_ALL(nn.Module):
    """SMPL-X wrapper with per-sample gender that also derives SMPL and body-part outputs.

    One :class:`SMPLX` model is built per gender in ``self.genders``. ``forward``
    routes each sample to the model selected by its integer gender code
    (0 = male, 1 = female, 2 = neutral). From the SMPL-X output it derives:

    * SMPL vertices, through the ``smplx_to_smpl`` transfer matrix;
    * hand, foot and face joint subsets;
    * the ``joint_map``-ordered joint set.
    """

    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        """
        Parameters
        ----------
        batch_size : int
            Passed to every per-gender :class:`SMPLX` model.
        use_face_contour : bool
            Passed to :class:`SMPLX`. Also selects the trailing 68 (contour) or 51 face
            joints in ``forward``.
        all_gender : bool
            Build male, female and neutral models if True; otherwise only neutral.
        **kwargs
            Extra keyword arguments for :class:`SMPLX`.
        """
        super().__init__()
        num_betas = 10
        self.use_face_contour = use_face_contour
        self.genders = ['male', 'female', 'neutral'] if all_gender else ['neutral']

        self.model_dict = nn.ModuleDict(
            {
                gender: SMPLX(
                    path_config.SMPL_MODEL_DIR,
                    gender=gender,
                    ext='npz',
                    num_betas=num_betas,
                    use_pca=False,
                    batch_size=batch_size,
                    use_face_contour=use_face_contour,
                    num_pca_comps=45,
                    **kwargs
                )
                for gender in self.genders
            }
        )
        self.model_neutral = self.model_dict['neutral']

        # Joint selection and extra regressor applied to the derived SMPL vertices.
        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        # NOTE: plain attribute, not a registered buffer.
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # SMPL-X -> SMPL vertex transfer matrix. The context manager closes the file
        # handle promptly. pickle can execute code on load: only use trusted model files.
        transfer_path = os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl')
        with open(transfer_path, 'rb') as f:
            smplx_to_smpl = pickle.load(f)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        # SMPL vertex ids of the left / right hand regions.
        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # SMPL-X joint indices of the hand joints, looked up by name.
        smplx2lhand_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        smplx2rhand_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
        self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)

        # SMPL-X joint indices of the foot joints, looked up by name.
        smplx2lfoot_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        smplx2rfoot_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
        self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)

        # Precompute a linear model of the first 24 regressed rest joints per gender:
        # J(betas) = J_template + J_dirs . betas  (used by get_tpose).
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )
            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    @staticmethod
    def _restore_order(gender_idx_list, batch_size, device):
        """Return the permutation that puts gender-grouped rows back into sample order.

        ``gender_idx_list[j]`` is the original sample index of row ``j`` of the
        gender-grouped concatenation, so the inverse permutation is its ``argsort``.

        Raises
        ------
        ValueError
            If the grouped rows do not cover every sample exactly once, for example
            when some sample has a gender code other than 0, 1 or 2.
        """
        if sorted(gender_idx_list) != list(range(batch_size)):
            raise ValueError(
                'Every sample needs a gender code of 0 (male), 1 (female) or 2 (neutral); '
                f'{len(gender_idx_list)} of {batch_size} samples matched.'
            )
        grouped = torch.tensor(gender_idx_list, dtype=torch.long, device=device)
        return torch.argsort(grouped)

    def forward(self, *args, **kwargs):
        """Run the gender-specific SMPL-X models and derive SMPL and body-part outputs.

        Keyword arguments:

        * ``body_pose`` is required; its first dimension is the batch size.
        * ``betas``, ``global_orient``, ``body_pose``, ``left_hand_pose``,
          ``right_hand_pose``, ``jaw_pose``, ``leye_pose`` and ``reye_pose`` are
          passed on to the SMPL-X models.
        * ``pose2rot`` defaults to True. When true, the pose entries are axis-angle
          and are converted with ``batch_rodrigues``.
        * ``gender`` holds per-sample codes (0 = male, 1 = female, 2 = neutral) and
          defaults to all neutral. Codes 0 and 1 require ``all_gender=True``.

        A 23-joint ``body_pose`` is truncated to its first 21 joints.

        Returns
        -------
        ModelOutput
            * ``vertices``: SMPL vertices (SMPL-X vertices mapped with ``smplx2smpl``).
            * ``smplx_vertices``: the SMPL-X vertices.
            * Hand vertices and joints, foot joints and face joints.
            * ``joints``, ``joints_J19`` and ``smpl_joints``.

        Raises
        ------
        ValueError
            If some sample's gender code is not 0, 1 or 2.
        """
        batch_size = kwargs['body_pose'].shape[0]
        device = kwargs['body_pose'].device
        kwargs['get_skin'] = True
        if 'pose2rot' not in kwargs:
            kwargs['pose2rot'] = True
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)

        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
                        [batch_size, -1, 3, 3]
                    )
        if kwargs['body_pose'].shape[1] == 23:
            # Keep the first 21 joints, presumably the SMPL-X body joints; the trailing
            # two of a 23-joint SMPL body pose are dropped.
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        # Run each gender's model on its own subset of samples.
        gender_idx_list = []
        smplx_vertices, smplx_joints = [], []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = ((kwargs['gender'] == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(int(idx) for idx in gender_idx)
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        # Undo the gender grouping so rows line up with the input samples again.
        idx_rearrange = self._restore_order(gender_idx_list, batch_size, device)
        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        # SMPL mesh via the linear SMPL-X -> SMPL vertex transfer.
        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]
        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)

        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]
        output = ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )
        return output

    def get_tpose(self, betas=None, gender=None):
        """Return rest-pose joints (first 24 ``J_regressor`` rows) for the given shapes.

        Uses the per-gender ``{g}_J_template`` / ``{g}_J_dirs`` buffers built in
        ``__init__``: ``J = J_template + blend_shapes(betas, J_dirs)``.

        Parameters
        ----------
        betas : torch.Tensor, optional
            Shape coefficients whose first dimension is the batch size. Defaults to
            one all-zero row of 10 coefficients.
        gender : torch.Tensor, optional
            Per-sample gender codes (0 = male, 1 = female, 2 = neutral). Defaults to
            all neutral.

        Returns
        -------
        torch.Tensor
            Joints in the order of ``betas``. The shape is ``(B, 24, 3)`` assuming the
            regressor has at least 24 rows.

        Raises
        ------
        ValueError
            If some sample has no valid gender code.
        """
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)
        batch_size = betas.shape[0]
        device = betas.device
        if gender is None:
            gender = 2 * torch.ones(batch_size).to(device)

        gender_idx_list = []
        smplx_joints = []
        for gi, g in enumerate(['male', 'female', 'neutral']):
            gender_idx = ((gender == gi).nonzero(as_tuple=True)[0])
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(int(idx) for idx in gender_idx)
            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                betas[gender_idx], getattr(self, f'{g}_J_dirs')
            )
            smplx_joints.append(J)

        idx_rearrange = self._restore_order(gender_idx_list, batch_size, device)
        return torch.cat(smplx_joints)[idx_rearrange]
|
|
|
|
|
class MANO(MANOLayer):
    """MANO hand layer extended with five vertex-based fingertip joints.

    ``forward`` accepts the pose as ``right_hand_pose`` and returns a
    :class:`ModelOutput` with ``rhand_vertices`` and a reordered ``rhand_joints``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run the MANO layer and append fingertip joints taken from mesh vertices.

        ``global_orient`` and ``right_hand_pose`` are required. The first dimension of
        ``global_orient`` is the batch size. When ``pose2rot`` is true (the default),
        both are converted from axis-angle to rotation matrices. ``right_hand_pose``
        is passed to the layer as ``hand_pose``; every other argument is forwarded
        unchanged.

        Returns
        -------
        ModelOutput
            ``rhand_vertices``: the layer's vertices.
            ``rhand_joints``: the layer's joints plus five fingertip vertices,
            reordered into a 21-entry layout.
        """
        kwargs.setdefault('pose2rot', True)
        n_samples = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for name in ('global_orient', 'right_hand_pose'):
                if name not in kwargs:
                    continue
                axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(axis_angle).view([n_samples, -1, 3, 3])
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')

        layer_out = super().forward(*args, **kwargs)
        verts = layer_out.vertices

        # Mesh vertices used as fingertip joints; appended after the layer's joints.
        tip_vertex_ids = [745, 317, 445, 556, 673]
        # Reordering of (layer joints + tips) into the 21-joint layout returned here.
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

        with_tips = torch.cat([layer_out.joints, verts[:, tip_vertex_ids]], 1)
        return ModelOutput(
            rhand_vertices=verts,
            rhand_joints=with_tips[:, joint_order],
        )
|
|
|
|
|
class FLAME(FLAMELayer):
    """FLAME head layer whose ``forward`` returns a :class:`ModelOutput`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run the FLAME layer and return vertices plus face joints.

        ``global_orient`` is required; its first dimension is the batch size. When
        ``pose2rot`` is true (the default), any given ``global_orient``,
        ``jaw_pose``, ``leye_pose`` and ``reye_pose`` are converted from axis-angle
        to rotation matrices before calling the layer.

        Returns
        -------
        ModelOutput
            ``flame_vertices``: the layer's vertices.
            ``face_joints``: the layer's joints without the first five.
        """
        kwargs.setdefault('pose2rot', True)
        n_samples = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            present = [
                name for name in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose')
                if name in kwargs
            ]
            for name in present:
                axis_angle = kwargs[name].contiguous().view(-1, 3)
                kwargs[name] = batch_rodrigues(axis_angle).view([n_samples, -1, 3, 3])

        layer_out = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=layer_out.vertices,
            face_joints=layer_out.joints[:, 5:],
        )
|
|
|
|
|
class SMPL_Family():
    """Build one body model by name and forward calls to it.

    ``model_type`` selects the wrapped model:

    * ``'smpl'``: :class:`SMPL`;
    * ``'smplx'``: :class:`SMPLX_ALL`;
    * ``'mano'``: right-hand :class:`MANO` without PCA;
    * ``'flame'``: :class:`FLAME` with face contour.

    Remaining arguments are passed to the model constructor.
    """

    def __init__(self, model_type='smpl', *args, **kwargs):
        """
        Raises
        ------
        ValueError
            If ``model_type`` is not one of ``'smpl'``, ``'smplx'``, ``'mano'`` or
            ``'flame'``.
        """
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            # Fail fast instead of leaving ``self.model`` unset (a later AttributeError).
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'smpl', 'smplx', 'mano' or 'flame'"
            )

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Delegate to the wrapped model's ``get_tpose`` (only SMPLX_ALL defines one here)."""
        return self.model.get_tpose(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_smpl_faces():
    """Load an SMPL model from ``SMPL_MODEL_DIR`` and return its ``faces``."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
|
def get_smplx_faces():
    """Load an SMPL-X model from ``SMPL_MODEL_DIR`` and return its ``faces``."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
|
def get_mano_faces(hand_type='right'):
    """Load a MANO model for ``hand_type`` (``'right'`` or ``'left'``) and return its ``faces``."""
    assert hand_type in ['right', 'left']
    right = hand_type == 'right'
    return MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=right).faces
|
|
|
|
|
def get_flame_faces():
    """Load a FLAME model from ``SMPL_MODEL_DIR`` and return its ``faces``."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces
|
|
|
|
|
def get_model_faces(type='smpl'):
    """Return the mesh faces of the named body model.

    Parameters
    ----------
    type : str
        ``'smpl'``, ``'smplx'``, ``'mano'`` (right hand) or ``'flame'``. The name
        shadows the builtin but is kept for callers that pass it by keyword.

    Raises
    ------
    ValueError
        If ``type`` is not a supported model name (previously ``None`` was
        returned silently).
    """
    if type == 'smpl':
        return get_smpl_faces()
    if type == 'smplx':
        return get_smplx_faces()
    if type == 'mano':
        return get_mano_faces()
    if type == 'flame':
        return get_flame_faces()
    raise ValueError(
        f"Unknown body model type {type!r}; expected 'smpl', 'smplx', 'mano' or 'flame'"
    )
|
|
|
|
|
def get_model_tpose(type='smpl'):
    """Return the rest-pose vertices of the named body model.

    Parameters
    ----------
    type : str
        ``'smpl'``, ``'smplx'``, ``'mano'`` (right hand) or ``'flame'``. The name
        shadows the builtin but is kept for callers that pass it by keyword.

    Raises
    ------
    ValueError
        If ``type`` is not a supported model name (previously ``None`` was
        returned silently).
    """
    if type == 'smpl':
        return get_smpl_tpose()
    if type == 'smplx':
        return get_smplx_tpose()
    if type == 'mano':
        return get_mano_tpose()
    if type == 'flame':
        return get_flame_tpose()
    raise ValueError(
        f"Unknown body model type {type!r}; expected 'smpl', 'smplx', 'mano' or 'flame'"
    )
|
|
|
|
|
def get_smpl_tpose():
    """Return the detached vertices of a default-parameter SMPL model (first sample)."""
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    return model().vertices[0].detach()
|
|
|
|
|
def get_smpl_tpose_joint():
    """Return the detached ``smpl_joints`` of a default-parameter SMPL model (first sample)."""
    model = SMPL(
        model_path=SMPL_MODEL_DIR,
        batch_size=1,
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
    )
    return model().smpl_joints[0].detach()
|
|
|
|
|
def get_smplx_tpose():
    """Return the vertices (first sample) of ``SMPLXLayer`` called with no arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().vertices[0]
|
|
|
|
|
def get_smplx_tpose_joint():
    """Return the joints (first sample) of ``SMPLXLayer`` called with no arguments."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    return layer().joints[0]
|
|
|
|
|
def get_mano_tpose():
    """Return right-hand MANO vertices for a zero global orientation and zero hand pose."""
    hand = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_pose = torch.zeros(1, 15 * 3)  # 15 joints x 3 axis-angle values
    return hand(global_orient=zero_orient, right_hand_pose=zero_pose).rhand_vertices[0]
|
|
|
|
|
def get_flame_tpose():
    """Return FLAME vertices for a zero global orientation (other inputs left to the layer)."""
    head = FLAME(SMPL_MODEL_DIR, batch_size=1)
    return head(global_orient=torch.zeros(1, 3)).flame_vertices[0]
|
|
|
|
|
def get_part_joints(smpl_joints):
    """Build 23 "part" joints from per-sample SMPL joints.

    * Outputs 0-17 are midpoints (means) of pairs of joints, in the order of
      ``pair_ids`` below.
    * Outputs 18-22 are copies of joints 10, 11, 15, 22 and 23.

    Parameters
    ----------
    smpl_joints : torch.Tensor
        Indexed as ``[batch, joint, ...]``. Needs at least 24 joints along dim 1 and
        a floating dtype (``torch.mean`` is used).

    Returns
    -------
    torch.Tensor
        ``(batch, 23, ...)`` part joints.
    """
    pair_ids = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17),
        (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21),
    ]
    single_ids = [10, 11, 15, 22, 23]

    part_joints = [
        torch.mean(smpl_joints[:, list(pair)], dim=1, keepdim=True) for pair in pair_ids
    ]
    # Slicing (idx:idx + 1) keeps the joint dimension so everything concatenates on dim 1.
    part_joints += [smpl_joints[:, idx:idx + 1] for idx in single_ids]
    return torch.cat(part_joints, dim=1)
|
|
|
|
|
# Vertex-segmentation labels (keys of ``<body_model>_vert_segmentation.json``) that make up
# each segmentation-based part built by get_partial_smpl. 'arm_eval' is defined but is not
# in get_partial_smpl's default part list.
_SEGMENTATION_PARTS = {
    'face': ['head'],
    'arm': [
        'rightHand',
        'leftArm',
        'leftShoulder',
        'rightShoulder',
        'rightArm',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'forearm': [
        'rightHand',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'arm_eval': ['leftArm', 'rightArm', 'leftForeArm', 'rightForeArm'],
    'larm': ['leftForeArm'],
    'rarm': ['rightForeArm'],
}


def _faces_touching(faces, vids):
    """Return the indices of the rows of ``faces`` that reference any vertex id in ``vids``."""
    return np.nonzero(np.isin(faces, vids).any(axis=1))[0]


def _extract_submesh(faces, num_verts, face_ids):
    """Cut the faces ``face_ids`` out of a mesh and re-index them compactly.

    Returns ``(vids, sub_faces)``:

    * ``vids``: the sorted unique vertex ids (into the full mesh) used by those faces;
    * ``sub_faces``: the same faces rewritten to index into ``vids``.
    """
    sub_faces = faces[face_ids]
    vids = np.unique(sub_faces).astype(np.longlong)
    remap = np.arange(num_verts)
    remap[vids] = np.arange(len(vids))
    return vids, remap[sub_faces].astype(np.longlong)


def _mano_hand_to_smpl_vids(part):
    """For each MANO vertex of hand ``part`` ('lhand'/'rhand'), find the closest SMPL vertex.

    The MANO vertices are taken from the SMPL-X rest mesh via
    ``MANO_SMPLX_vertex_ids.pkl``. They are compared against the SMPL rest mesh
    obtained with the ``smplx_to_smpl`` transfer matrix.
    """
    # NOTE: pickle can execute code on load; these must be trusted model files.
    with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb') as f:
        smplx_mano_id = pickle.load(f)
    with open(os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb') as f:
        smplx_smpl_id = pickle.load(f)

    smplx_tpose = get_smplx_tpose()
    smpl_tpose = np.matmul(smplx_smpl_id['matrix'], smplx_tpose)

    hand_key = 'left_hand' if part == 'lhand' else 'right_hand'
    mano_vert = smplx_tpose[smplx_mano_id[hand_key]]

    smpl2mano_id = []
    for vert in mano_vert:
        v_diff = smpl_tpose - vert
        v_diff = torch.sum(v_diff * v_diff, dim=1)
        smpl2mano_id.append(int(torch.argmin(v_diff)))
    return np.array(smpl2mano_id).astype(np.longlong)


def _wrist_vertex_ids(part, body_model):
    """Return ids of rest-pose vertices whose squared distance to the wrist joint is < 0.005.

    ``part`` is 'lwrist' (joint 20) or 'rwrist' (joint 21). Only 'smpl' and 'smplx'
    are supported.
    """
    if body_model == 'smplx':
        body_model_verts = get_smplx_tpose()
        tpose_joint = get_smplx_tpose_joint()
    elif body_model == 'smpl':
        body_model_verts = get_smpl_tpose()
        tpose_joint = get_smpl_tpose_joint()
    else:
        raise ValueError(
            f"Wrist parts are only supported for 'smpl' and 'smplx', got {body_model!r}"
        )

    wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]
    sq_dist = torch.sum((body_model_verts - wrist_joint) ** 2, dim=1)
    return torch.nonzero(sq_dist < 0.005, as_tuple=True)[0].tolist()


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Return vertex ids and re-indexed faces for several body parts of a body model.

    Parts are ``lhand, rhand, face, arm, forearm, larm, rarm, lwrist, rwrist``. Each one
    is loaded from ``{PARTIAL_MESH_DIR}/{body_model}_{part}_vids.npz`` if that file
    exists; otherwise it is computed and written to that path.

    Parameters
    ----------
    body_model : str
        A model name accepted by ``get_model_faces`` / ``get_model_tpose``.
        Segmentation parts need ``{body_model}_vert_segmentation.json`` in
        ``SMPL_MODEL_DIR``. Wrist parts support only ``'smpl'`` and ``'smplx'``.
    device : torch.device
        Unused; kept for backward compatibility.

    Returns
    -------
    dict
        ``part -> {'vids': ..., 'faces': ...}``. ``vids`` index the full mesh.
        For segmentation and wrist parts, ``faces`` index into ``vids``. For hand
        parts, ``faces`` are the MANO faces.

    Raises
    ------
    ValueError
        If a wrist part must be computed for an unsupported ``body_model``.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    part_vert_faces = {}
    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)
        if os.path.exists(part_vid_fname):
            # NpzFile is a context manager; close the archive once both arrays are read.
            with np.load(part_vid_fname) as part_vids:
                part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
            continue

        if part in ('lhand', 'rhand'):
            # NOTE(review): this yields SMPL vertex ids regardless of body_model — confirm intended.
            vids = _mano_hand_to_smpl_vids(part)
            faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left'
                                   ).astype(np.longlong)
        elif part in _SEGMENTATION_PARTS:
            seg_path = os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model))
            with open(seg_path, 'rb') as f:
                segmentation = json.load(f)
            seg_vids = [vid for label in _SEGMENTATION_PARTS[part] for vid in segmentation[label]]
            face_ids = _faces_touching(body_model_faces, seg_vids)
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, face_ids)
        else:  # 'lwrist' / 'rwrist'
            wrist_vids = _wrist_vertex_ids(part, body_model)
            face_ids = _faces_touching(body_model_faces, wrist_vids)
            vids, faces = _extract_submesh(body_model_faces, body_model_num_verts, face_ids)

        np.savez(part_vid_fname, vids=vids, faces=faces)
        part_vert_faces[part] = {'vids': vids, 'faces': faces}

    return part_vert_faces
|
|