# Hosting-page header (not code): uploader fffiloni; migrated from GitHub; revision 2252f3d (verified); file size 38.4 kB.
# This script is extended based on https://github.com/nkolot/SPIN/blob/master/models/smpl.py
from typing import Optional
from dataclasses import dataclass
import os
import torch
import torch.nn as nn
import numpy as np
import pickle
from lib.smplx import SMPL as _SMPL
from lib.smplx import SMPLXLayer, MANOLayer, FLAMELayer
from lib.smplx.lbs import batch_rodrigues, transform_mat, vertices2joints, blend_shapes
from lib.smplx.body_models import SMPLXOutput
import json
from lib.pymafx.core import path_config, constants
# Module-level aliases for values defined in the project's path configuration.
# SMPL_MODEL_DIR is used below as the directory holding the body-model files
# (it is passed to the model constructors and joined with 'model_transfer/...').
# NOTE(review): SMPL_MEAN_PARAMS is not referenced in this file's visible code;
# presumably it is re-exported for other modules - confirm before removing.
SMPL_MEAN_PARAMS = path_config.SMPL_MEAN_PARAMS
SMPL_MODEL_DIR = path_config.SMPL_MODEL_DIR
@dataclass
class ModelOutput(SMPLXOutput):
    """Output container shared by every model wrapper in this file.

    It extends ``SMPLXOutput`` with additional optional tensors. All added
    fields default to ``None``; each wrapper fills in only what it computes.
    """
    # First 24 joints of the body model's joint output (set by SMPL / SMPLX_ALL).
    smpl_joints: Optional[torch.Tensor] = None
    # Last 24 mapped joints re-indexed with constants.J24_TO_J19.
    joints_J19: Optional[torch.Tensor] = None
    # Vertices produced by the SMPL-X model (set by SMPLX_ALL).
    smplx_vertices: Optional[torch.Tensor] = None
    # Vertices produced by the FLAME wrapper.
    flame_vertices: Optional[torch.Tensor] = None
    # Left-hand vertices (subset of the SMPL-topology mesh in SMPLX_ALL).
    lhand_vertices: Optional[torch.Tensor] = None
    # Right-hand vertices (SMPLX_ALL subset, or the full mesh from the MANO wrapper).
    rhand_vertices: Optional[torch.Tensor] = None
    # Left-hand joints selected from the SMPL-X joints.
    lhand_joints: Optional[torch.Tensor] = None
    # Right-hand joints (SMPL-X selection, or MANO joints plus fingertips).
    rhand_joints: Optional[torch.Tensor] = None
    # Face joints/landmarks (tail of the SMPL-X joints, or FLAME joints from index 5 on).
    face_joints: Optional[torch.Tensor] = None
    # Left-foot joints selected from the SMPL-X joints.
    lfoot_joints: Optional[torch.Tensor] = None
    # Right-foot joints selected from the SMPL-X joints.
    rfoot_joints: Optional[torch.Tensor] = None
class SMPL(_SMPL):
    """SMPL wrapper that adds extra regressed joints and a rotation-chain helper.

    NOTE(review): the four ``create_*`` flags precede ``*args`` in the
    signature, so positional arguments bind to them first; every call in this
    file passes ``model_path`` by keyword.
    """

    def __init__(
        self,
        create_betas=False,
        create_global_orient=False,
        create_body_pose=False,
        create_transl=False,
        *args,
        **kwargs
    ):
        super().__init__(
            create_betas=create_betas,
            create_global_orient=create_global_orient,
            create_body_pose=create_body_pose,
            create_transl=create_transl,
            *args,
            **kwargs
        )
        # Indices used by forward() to pick/reorder the concatenated joints.
        mapped_indices = [constants.JOINT_MAP[name] for name in constants.JOINT_NAMES]
        extra_regressor = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(extra_regressor, dtype=torch.float32)
        )
        # Kept as a plain attribute (not a buffer), exactly as before.
        self.joint_map = torch.tensor(mapped_indices, dtype=torch.long)
        # Joints regressed from the template mesh, cached for get_global_rotation().
        rest_joints = vertices2joints(self.J_regressor, self.v_template.unsqueeze(0))
        self.register_buffer('tpose_joints', rest_joints)

    def forward(self, *args, **kwargs):
        """Run the parent SMPL forward pass and package an extended output.

        ``get_skin`` is always forced to ``True``. The parent's joints are
        concatenated with joints regressed by ``J_regressor_extra`` and then
        re-indexed with ``joint_map``.

        Returns
        -------
        ModelOutput
            ``vertices``, ``global_orient``, ``body_pose``, ``betas`` and
            ``full_pose`` are forwarded from the parent output; ``joints`` is
            the mapped joint set, ``smpl_joints`` the parent's first 24 joints
            and ``joints_J19`` the last 24 mapped joints re-indexed with
            ``constants.J24_TO_J19``.
        """
        kwargs['get_skin'] = True
        base = super().forward(*args, **kwargs)
        verts = base.vertices
        regressed = vertices2joints(self.J_regressor_extra, verts)
        # Parent joints followed by the extra regressed joints, then remapped.
        mapped = torch.cat([base.joints, regressed], dim=1)[:, self.joint_map, :]
        last_24 = mapped[:, -24:, :]
        return ModelOutput(
            vertices=verts,
            global_orient=base.global_orient,
            body_pose=base.body_pose,
            joints=mapped,
            joints_J19=last_24[:, constants.J24_TO_J19, :],
            smpl_joints=base.joints[:, :24],
            betas=base.betas,
            full_pose=base.full_pose
        )

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Compose per-joint rotations along the kinematic tree of the template.

        Parameters
        ----------
        global_orient: torch.tensor, optional
            Root rotation in rotation-matrix form; reshaped to (-1, 1, 3, 3).
            Identity is used when omitted.
        body_pose: torch.tensor, optional
            Body joint rotations in rotation-matrix form; reshaped to
            (-1, NUM_BODY_JOINTS, 3, 3). Identity is used when omitted.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        global_rotmat: the 3x3 rotation part of every joint's chained transform.
        posed_joints: the translation column of every joint's chained transform.
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        # The batch size is the largest leading dimension among the given inputs.
        batch_size = 1
        for pose in (global_orient, body_pose):
            if pose is not None:
                batch_size = max(batch_size, len(pose))

        def identity_rotations(num_joints):
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)

        # Root rotation first, then the body joints.
        full_pose = torch.cat(
            [
                global_orient.reshape(-1, 1, 3, 3),
                body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            ],
            dim=1
        )
        rot_mats = full_pose.view(batch_size, -1, 3, 3)

        # Cached template joints, one copy per batch element, as column vectors.
        joints = self.tpose_joints.expand(batch_size, -1, -1).unsqueeze(-1)
        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        num_joints = joints.shape[1]
        local_transforms = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, num_joints, 4, 4)

        # Walk the tree: each joint's transform is its parent's chained
        # transform times its own local transform.
        chain = [local_transforms[:, 0]]
        for joint_idx in range(1, self.parents.shape[0]):
            parent_transform = chain[self.parents[joint_idx]]
            chain.append(torch.matmul(parent_transform, local_transforms[:, joint_idx]))
        transforms = torch.stack(chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]
        posed_joints = transforms[:, :, :3, 3]
        return global_rotmat, posed_joints
class SMPLX(SMPLXLayer):
    """SMPLXLayer with an added helper that chains joint rotations over the template."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_global_rotation(
        self,
        global_orient: Optional[torch.Tensor] = None,
        body_pose: Optional[torch.Tensor] = None,
        left_hand_pose: Optional[torch.Tensor] = None,
        right_hand_pose: Optional[torch.Tensor] = None,
        jaw_pose: Optional[torch.Tensor] = None,
        leye_pose: Optional[torch.Tensor] = None,
        reye_pose: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Compose per-joint rotations along the kinematic tree of the template.

        All pose arguments are expected in rotation-matrix form and default
        to identity when omitted.

        Parameters
        ----------
        global_orient: torch.tensor, optional
            Root rotation; reshaped to (-1, 1, 3, 3).
        body_pose: torch.tensor, optional
            Body joint rotations; reshaped to (-1, NUM_BODY_JOINTS, 3, 3).
        left_hand_pose: torch.tensor, optional
            Left-hand rotations; reshaped to (-1, NUM_HAND_JOINTS, 3, 3).
            The identity default has 15 joints.
        right_hand_pose: torch.tensor, optional
            Right-hand rotations; same handling as ``left_hand_pose``.
        jaw_pose: torch.tensor, optional
            Jaw rotation; reshaped to (-1, 1, 3, 3).
        leye_pose: torch.tensor, optional
            Left-eye rotation; reshaped to (-1, 1, 3, 3).
        reye_pose: torch.tensor, optional
            Right-eye rotation; reshaped to (-1, 1, 3, 3).
        **kwargs
            Accepted and ignored.

        The batch size is the largest leading dimension among
        ``global_orient``, ``body_pose``, the two hand poses and ``jaw_pose``;
        the eye poses do not take part in that computation.

        Returns
        -------
        global_rotmat: the 3x3 rotation part of every joint's chained transform.
        posed_joints: the translation column of every joint's chained transform.
        """
        device, dtype = self.shapedirs.device, self.shapedirs.dtype

        batch_size = 1
        for pose in (global_orient, body_pose, left_hand_pose, right_hand_pose, jaw_pose):
            if pose is not None:
                batch_size = max(batch_size, len(pose))

        def identity_rotations(num_joints):
            eye = torch.eye(3, device=device, dtype=dtype).view(1, 1, 3, 3)
            return eye.expand(batch_size, num_joints, -1, -1).contiguous()

        if global_orient is None:
            global_orient = identity_rotations(1)
        if body_pose is None:
            body_pose = identity_rotations(self.NUM_BODY_JOINTS)
        if left_hand_pose is None:
            left_hand_pose = identity_rotations(15)
        if right_hand_pose is None:
            right_hand_pose = identity_rotations(15)
        if jaw_pose is None:
            jaw_pose = identity_rotations(1)
        if leye_pose is None:
            leye_pose = identity_rotations(1)
        if reye_pose is None:
            reye_pose = identity_rotations(1)

        # Concatenation order: root, body, jaw, left eye, right eye, left hand, right hand.
        pose_blocks = [
            global_orient.reshape(-1, 1, 3, 3),
            body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3, 3),
            jaw_pose.reshape(-1, 1, 3, 3),
            leye_pose.reshape(-1, 1, 3, 3),
            reye_pose.reshape(-1, 1, 3, 3),
            left_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
            right_hand_pose.reshape(-1, self.NUM_HAND_JOINTS, 3, 3),
        ]
        rot_mats = torch.cat(pose_blocks, dim=1).view(batch_size, -1, 3, 3)

        # Regress the joints from the template mesh, once per batch element.
        template = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        joints = torch.unsqueeze(vertices2joints(self.J_regressor, template), dim=-1)

        # Express every non-root joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, self.parents[1:]]

        num_joints = joints.shape[1]
        local_transforms = transform_mat(
            rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)
        ).reshape(-1, num_joints, 4, 4)

        # Each joint's transform is its parent's chained transform times its
        # own local transform.
        chain = [local_transforms[:, 0]]
        for joint_idx in range(1, self.parents.shape[0]):
            parent_transform = chain[self.parents[joint_idx]]
            chain.append(torch.matmul(parent_transform, local_transforms[:, joint_idx]))
        transforms = torch.stack(chain, dim=1)

        global_rotmat = transforms[:, :, :3, :3]
        posed_joints = transforms[:, :, :3, 3]
        return global_rotmat, posed_joints
class SMPLX_ALL(nn.Module):
    """Gender-aware SMPL-X wrapper that also produces SMPL-topology outputs.

    One ``SMPLX`` model is built per supported gender. ``forward`` routes each
    batch element to the model matching its integer gender code, restores the
    original batch order, transfers the SMPL-X vertices to the SMPL topology
    with the ``smplx2smpl`` matrix and derives the joint sets stored in
    ``ModelOutput``.
    """

    # Integer gender codes used by forward()/get_tpose() index into this tuple.
    _GENDER_ORDER = ('male', 'female', 'neutral')

    def __init__(self, batch_size=1, use_face_contour=True, all_gender=False, **kwargs):
        """Build the per-gender models and all lookup tables.

        Parameters
        ----------
        batch_size: forwarded to every ``SMPLX`` constructor.
        use_face_contour: forwarded to ``SMPLX``; also selects how many
            trailing joints ``forward`` reports as face joints (68 vs 51).
        all_gender: when true, male/female/neutral models are created,
            otherwise only the neutral one.
        **kwargs: forwarded to every ``SMPLX`` constructor.
        """
        super().__init__()
        numBetas = 10
        self.use_face_contour = use_face_contour
        self.genders = ['male', 'female', 'neutral'] if all_gender else ['neutral']
        self.model_dict = nn.ModuleDict(
            {
                gender: SMPLX(
                    path_config.SMPL_MODEL_DIR,
                    gender=gender,
                    ext='npz',
                    num_betas=numBetas,
                    use_pca=False,
                    batch_size=batch_size,
                    use_face_contour=use_face_contour,
                    num_pca_comps=45,
                    **kwargs
                )
                for gender in self.genders
            }
        )
        # Second reference to the neutral model, kept so the attribute name
        # (and therefore the registered sub-module name) stays unchanged.
        self.model_neutral = self.model_dict['neutral']

        joints = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        J_regressor_extra = np.load(path_config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer(
            'J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

        # smplx_to_smpl.pkl, file source: https://smpl-x.is.tue.mpg.de
        # NOTE: pickle.load can execute arbitrary code; this file must come
        # from a trusted model distribution.
        transfer_path = os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl')
        with open(transfer_path, 'rb') as transfer_file:
            smplx_to_smpl = pickle.load(transfer_file)
        self.register_buffer(
            'smplx2smpl', torch.tensor(smplx_to_smpl['matrix'][None], dtype=torch.float32)
        )

        # Vertex ids of the hands within the SMPL-topology mesh.
        smpl2limb_vert_faces = get_partial_smpl('smpl')
        self.smpl2lhand = torch.from_numpy(smpl2limb_vert_faces['lhand']['vids']).long()
        self.smpl2rhand = torch.from_numpy(smpl2limb_vert_faces['rhand']['vids']).long()

        # left and right hand joint mapping
        smplx2lhand_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        smplx2rhand_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.HAND_NAMES
        ]
        self.smplx2lh_joint_map = torch.tensor(smplx2lhand_joints, dtype=torch.long)
        self.smplx2rh_joint_map = torch.tensor(smplx2rhand_joints, dtype=torch.long)

        # left and right foot joint mapping
        smplx2lfoot_joints = [
            constants.SMPLX_JOINT_IDS['left_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        smplx2rfoot_joints = [
            constants.SMPLX_JOINT_IDS['right_{}'.format(name)] for name in constants.FOOT_NAMES
        ]
        self.smplx2lf_joint_map = torch.tensor(smplx2lfoot_joints, dtype=torch.long)
        self.smplx2rf_joint_map = torch.tensor(smplx2rfoot_joints, dtype=torch.long)

        # Per-gender joint template and shape directions for the first 24
        # regressor rows; used by get_tpose().
        for g in self.genders:
            J_template = torch.einsum(
                'ji,ik->jk', [self.model_dict[g].J_regressor[:24], self.model_dict[g].v_template]
            )
            J_dirs = torch.einsum(
                'ji,ikl->jkl', [self.model_dict[g].J_regressor[:24], self.model_dict[g].shapedirs]
            )
            self.register_buffer(f'{g}_J_template', J_template)
            self.register_buffer(f'{g}_J_dirs', J_dirs)

    @staticmethod
    def _restore_batch_order(gathered_idx, batch_size, device):
        """Return the index tensor that undoes the per-gender grouping.

        ``gathered_idx[j]`` is the original batch position of the j-th row of
        the concatenated per-gender results. The returned tensor ``r``
        satisfies ``gathered_idx[r[i]] == i``, so indexing the concatenated
        results with it yields rows in the original batch order.

        Raises
        ------
        ValueError
            If not every batch element was assigned to a gender, i.e. some
            gender code is outside 0..2.
        """
        if len(gathered_idx) != batch_size:
            raise ValueError(
                'gender codes must be 0 (male), 1 (female) or 2 (neutral); '
                'matched {} of {} batch elements'.format(len(gathered_idx), batch_size)
            )
        order = torch.tensor(gathered_idx, dtype=torch.long, device=device)
        # argsort of a permutation is its inverse permutation.
        return torch.argsort(order)

    def forward(self, *args, **kwargs):
        """Run the per-gender SMPL-X models and assemble a ``ModelOutput``.

        Keyword arguments read here: ``body_pose`` (required), ``gender``
        (integer codes, defaults to 2 for every element), ``pose2rot``
        (defaults to ``True``), ``betas`` and the pose tensors
        ``global_orient``, ``left_hand_pose``, ``right_hand_pose``,
        ``jaw_pose``, ``leye_pose``, ``reye_pose``.

        When ``pose2rot`` is true every provided pose tensor is viewed as
        (-1, 3) vectors and converted with ``batch_rodrigues``. A body pose
        with 23 joints is truncated to its first 21.

        Raises
        ------
        ValueError
            If a gender code is outside 0..2.
        """
        batch_size = kwargs['body_pose'].shape[0]
        device = kwargs['body_pose'].device
        kwargs['get_skin'] = True
        kwargs.setdefault('pose2rot', True)
        if 'gender' not in kwargs:
            kwargs['gender'] = 2 * torch.ones(batch_size).to(device)

        # pose for 55 joints: 1, 21, 15, 15, 1, 1, 1
        pose_keys = [
            'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose', 'jaw_pose',
            'leye_pose', 'reye_pose'
        ]
        param_keys = ['betas'] + pose_keys
        if kwargs['pose2rot']:
            for key in pose_keys:
                if key in kwargs:
                    kwargs[key] = batch_rodrigues(kwargs[key].contiguous().view(-1, 3)).view(
                        [batch_size, -1, 3, 3]
                    )
        if kwargs['body_pose'].shape[1] == 23:
            # remove hand pose in the body_pose
            kwargs['body_pose'] = kwargs['body_pose'][:, :21]

        # Evaluate each gender's model on its share of the batch.
        gender_idx_list = []
        smplx_vertices, smplx_joints = [], []
        for gi, g in enumerate(self._GENDER_ORDER):
            gender_idx = (kwargs['gender'] == gi).nonzero(as_tuple=True)[0]
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(gender_idx.tolist())
            gender_kwargs = {'get_skin': kwargs['get_skin'], 'pose2rot': kwargs['pose2rot']}
            gender_kwargs.update({k: kwargs[k][gender_idx] for k in param_keys if k in kwargs})
            gender_smplx_output = self.model_dict[g].forward(*args, **gender_kwargs)
            smplx_vertices.append(gender_smplx_output.vertices)
            smplx_joints.append(gender_smplx_output.joints)

        idx_rearrange = self._restore_batch_order(gender_idx_list, batch_size, device)
        smplx_vertices = torch.cat(smplx_vertices)[idx_rearrange]
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]

        # constants.HAND_NAMES
        lhand_joints = smplx_joints[:, self.smplx2lh_joint_map]
        rhand_joints = smplx_joints[:, self.smplx2rh_joint_map]
        # constants.FACIAL_LANDMARKS
        face_joints = smplx_joints[:, -68:] if self.use_face_contour else smplx_joints[:, -51:]
        # constants.FOOT_NAMES
        lfoot_joints = smplx_joints[:, self.smplx2lf_joint_map]
        rfoot_joints = smplx_joints[:, self.smplx2rf_joint_map]

        # Transfer the SMPL-X vertices to the SMPL topology.
        smpl_vertices = torch.bmm(self.smplx2smpl.expand(batch_size, -1, -1), smplx_vertices)
        lhand_vertices = smpl_vertices[:, self.smpl2lhand]
        rhand_vertices = smpl_vertices[:, self.smpl2rhand]

        extra_joints = vertices2joints(self.J_regressor_extra, smpl_vertices)
        smplx_j45 = smplx_joints[:, constants.SMPLX2SMPL_J45]
        joints = torch.cat([smplx_j45, extra_joints], dim=1)
        smpl_joints = smplx_j45[:, :24]
        joints = joints[:, self.joint_map, :]
        joints_J24 = joints[:, -24:, :]
        joints_J19 = joints_J24[:, constants.J24_TO_J19, :]

        # global_orient / body_pose / betas / full_pose are intentionally left
        # unset in the output.
        output = ModelOutput(
            vertices=smpl_vertices,
            smplx_vertices=smplx_vertices,
            lhand_vertices=lhand_vertices,
            rhand_vertices=rhand_vertices,
            joints=joints,
            joints_J19=joints_J19,
            smpl_joints=smpl_joints,
            lhand_joints=lhand_joints,
            rhand_joints=rhand_joints,
            lfoot_joints=lfoot_joints,
            rfoot_joints=rfoot_joints,
            face_joints=face_joints,
        )
        return output

    # def make_hand_regressor(self):
    #     # borrowed from https://github.com/mks0601/Hand4Whole_RELEASE/blob/main/common/utils/human_models.py
    #     regressor = self.model_neutral.J_regressor.numpy()
    #     vertex_num = self.model_neutral.J_regressor.shape[-1]
    #     lhand_regressor = np.concatenate((regressor[[20,37,38,39],:],
    #                                         np.eye(vertex_num)[5361,None],
    #                                         regressor[[25,26,27],:],
    #                                         np.eye(vertex_num)[4933,None],
    #                                         regressor[[28,29,30],:],
    #                                         np.eye(vertex_num)[5058,None],
    #                                         regressor[[34,35,36],:],
    #                                         np.eye(vertex_num)[5169,None],
    #                                         regressor[[31,32,33],:],
    #                                         np.eye(vertex_num)[5286,None]))
    #     rhand_regressor = np.concatenate((regressor[[21,52,53,54],:],
    #                                         np.eye(vertex_num)[8079,None],
    #                                         regressor[[40,41,42],:],
    #                                         np.eye(vertex_num)[7669,None],
    #                                         regressor[[43,44,45],:],
    #                                         np.eye(vertex_num)[7794,None],
    #                                         regressor[[49,50,51],:],
    #                                         np.eye(vertex_num)[7905,None],
    #                                         regressor[[46,47,48],:],
    #                                         np.eye(vertex_num)[8022,None]))
    #     return torch.from_numpy(lhand_regressor).float(), torch.from_numpy(rhand_regressor).float()

    def get_tpose(self, betas=None, gender=None):
        """Return the shaped rest-pose joints for each batch element.

        Parameters
        ----------
        betas: shape coefficients; defaults to a single all-zero row of 10
            values on the module's device.
        gender: integer gender codes per element (0 male, 1 female,
            2 neutral); defaults to neutral for every element.

        Returns
        -------
        Joints computed as the per-gender joint template plus the shape blend
        of the per-gender joint directions, in the original batch order.

        Raises
        ------
        ValueError
            If a gender code is outside 0..2.
        """
        if betas is None:
            betas = torch.zeros(1, 10).to(self.J_regressor_extra.device)
        batch_size = betas.shape[0]
        device = betas.device
        if gender is None:
            gender = 2 * torch.ones(batch_size).to(device)

        gender_idx_list = []
        smplx_joints = []
        for gi, g in enumerate(self._GENDER_ORDER):
            gender_idx = (gender == gi).nonzero(as_tuple=True)[0]
            if len(gender_idx) == 0:
                continue
            gender_idx_list.extend(gender_idx.tolist())
            J = getattr(self, f'{g}_J_template').unsqueeze(0) + blend_shapes(
                betas[gender_idx], getattr(self, f'{g}_J_dirs')
            )
            smplx_joints.append(J)

        idx_rearrange = self._restore_batch_order(gender_idx_list, batch_size, device)
        smplx_joints = torch.cat(smplx_joints)[idx_rearrange]
        return smplx_joints
class MANO(MANOLayer):
    """MANOLayer wrapper that appends fingertip joints and reorders the joint set."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run MANO and return hand vertices plus a 21-joint set.

        Keyword arguments read here: ``global_orient`` (required, defines the
        batch size), ``right_hand_pose`` (required, passed to the parent as
        ``hand_pose``) and ``pose2rot`` (defaults to ``True``). When
        ``pose2rot`` is true both pose tensors are viewed as (-1, 3) vectors
        and converted with ``batch_rodrigues``.

        Returns
        -------
        ModelOutput
            ``rhand_vertices`` holds the parent's vertices; ``rhand_joints``
            holds the parent's joints followed by five sampled vertices,
            reordered.
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for key in ('global_orient', 'right_hand_pose'):
                if key not in kwargs:
                    continue
                as_vectors = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(as_vectors).view([batch_size, -1, 3, 3])
        # The parent layer expects the hand pose under the name 'hand_pose'.
        kwargs['hand_pose'] = kwargs.pop('right_hand_pose')
        mano_output = super().forward(*args, **kwargs)
        hand_verts = mano_output.vertices

        # https://github.com/hassony2/manopth/blob/master/manopth/manolayer.py#L248-L260
        # In addition to the MANO reference joints, one vertex per finger is
        # sampled to serve as the fingertip.
        fingertip_vids = [745, 317, 445, 556, 673]
        hand_joints = torch.cat([mano_output.joints, hand_verts[:, fingertip_vids]], 1)

        # Reorder joints to match visualization utilities
        joint_order = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
        hand_joints = hand_joints[:, joint_order]

        return ModelOutput(
            rhand_vertices=hand_verts,
            rhand_joints=hand_joints,
        )
class FLAME(FLAMELayer):
    """FLAMELayer wrapper that accepts vector poses and returns a ``ModelOutput``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Run FLAME and return its vertices and face joints.

        Keyword arguments read here: ``global_orient`` (required, defines the
        batch size), ``jaw_pose``, ``leye_pose``, ``reye_pose`` and
        ``pose2rot`` (defaults to ``True``). When ``pose2rot`` is true every
        provided pose tensor is viewed as (-1, 3) vectors and converted with
        ``batch_rodrigues``.

        Returns
        -------
        ModelOutput
            ``flame_vertices`` holds the parent's vertices and ``face_joints``
            the parent's joints from index 5 onwards.
        """
        kwargs.setdefault('pose2rot', True)
        batch_size = kwargs['global_orient'].shape[0]
        if kwargs['pose2rot']:
            for key in ('global_orient', 'jaw_pose', 'leye_pose', 'reye_pose'):
                if key not in kwargs:
                    continue
                as_vectors = kwargs[key].contiguous().view(-1, 3)
                kwargs[key] = batch_rodrigues(as_vectors).view([batch_size, -1, 3, 3])
        flame_output = super().forward(*args, **kwargs)
        return ModelOutput(
            flame_vertices=flame_output.vertices,
            face_joints=flame_output.joints[:, 5:],
        )
class SMPL_Family():
    """Thin facade that builds one of the model wrappers by name.

    Supported ``model_type`` values are ``'smpl'``, ``'smplx'``, ``'mano'``
    and ``'flame'``. The constructed model is stored in ``self.model`` and
    calls on this object are forwarded to it.
    """

    def __init__(self, model_type='smpl', *args, **kwargs):
        """Construct the wrapped model.

        Extra positional and keyword arguments are forwarded to the selected
        model's constructor.

        Raises
        ------
        ValueError
            If ``model_type`` is not one of the supported names. Previously
            the object was created without a ``model`` attribute and failed
            later with an unrelated ``AttributeError``.
        """
        if model_type == 'smpl':
            self.model = SMPL(model_path=SMPL_MODEL_DIR, *args, **kwargs)
        elif model_type == 'smplx':
            self.model = SMPLX_ALL(*args, **kwargs)
        elif model_type == 'mano':
            self.model = MANO(
                model_path=SMPL_MODEL_DIR, is_rhand=True, use_pca=False, *args, **kwargs
            )
        elif model_type == 'flame':
            self.model = FLAME(model_path=SMPL_MODEL_DIR, use_face_contour=True, *args, **kwargs)
        else:
            raise ValueError(
                "unsupported model_type {!r}; expected 'smpl', 'smplx', 'mano' or 'flame'".format(
                    model_type
                )
            )

    def __call__(self, *args, **kwargs):
        """Forward the call to the wrapped model."""
        return self.model(*args, **kwargs)

    def get_tpose(self, *args, **kwargs):
        """Forward to the wrapped model's ``get_tpose``.

        Among the wrappers defined in this file only ``SMPLX_ALL`` defines
        that method.
        """
        return self.model.get_tpose(*args, **kwargs)
# def to(self, device):
# self.model.to(device)
# def cuda(self, device=None):
# if device is None:
# self.model.cuda()
# else:
# self.model.cuda(device)
def get_smpl_faces():
    """Return the ``faces`` attribute of a freshly constructed SMPL model."""
    return SMPL(model_path=SMPL_MODEL_DIR, batch_size=1).faces
def get_smplx_faces():
    """Return the ``faces`` attribute of a freshly constructed SMPLX model."""
    return SMPLX(SMPL_MODEL_DIR, batch_size=1).faces
def get_mano_faces(hand_type='right'):
    """Return the ``faces`` attribute of a MANO model for the requested hand.

    ``hand_type`` must be ``'right'`` or ``'left'``; anything else trips the
    assertion.
    """
    assert hand_type in ['right', 'left']
    use_right_hand = hand_type == 'right'
    return MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=use_right_hand).faces
def get_flame_faces():
    """Return the ``faces`` attribute of a freshly constructed FLAME model."""
    return FLAME(SMPL_MODEL_DIR, batch_size=1).faces
def get_model_faces(type='smpl'):
    """Return the faces of the model named by ``type``.

    Recognised names are ``'smpl'``, ``'smplx'``, ``'mano'`` and ``'flame'``;
    any other value yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_faces()
    if type == 'smplx':
        return get_smplx_faces()
    if type == 'mano':
        return get_mano_faces()
    if type == 'flame':
        return get_flame_faces()
    return None
def get_model_tpose(type='smpl'):
    """Return the rest-pose vertices of the model named by ``type``.

    Recognised names are ``'smpl'``, ``'smplx'``, ``'mano'`` and ``'flame'``;
    any other value yields ``None``.
    """
    if type == 'smpl':
        return get_smpl_tpose()
    if type == 'smplx':
        return get_smplx_tpose()
    if type == 'mano':
        return get_mano_tpose()
    if type == 'flame':
        return get_flame_tpose()
    return None
def get_smpl_tpose():
    """Return the detached vertices of the first sample of a default SMPL forward pass.

    The model is built with its own betas, global orientation and body pose
    parameters and is called without arguments.
    """
    model = SMPL(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1
    )
    output = model()
    return output.vertices[0].detach()
def get_smpl_tpose_joint():
    """Return the detached ``smpl_joints`` of the first sample of a default SMPL forward pass.

    The model is built with its own betas, global orientation and body pose
    parameters and is called without arguments.
    """
    model = SMPL(
        create_betas=True,
        create_global_orient=True,
        create_body_pose=True,
        model_path=SMPL_MODEL_DIR,
        batch_size=1
    )
    output = model()
    return output.smpl_joints[0].detach()
def get_smplx_tpose():
    """Return the vertices of the first sample of an argument-less SMPLXLayer forward pass."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    output = layer()
    return output.vertices[0]
def get_smplx_tpose_joint():
    """Return the joints of the first sample of an argument-less SMPLXLayer forward pass."""
    layer = SMPLXLayer(SMPL_MODEL_DIR, batch_size=1)
    output = layer()
    return output.joints[0]
def get_mano_tpose():
    """Return the right-hand MANO vertices for all-zero orientation and hand pose."""
    hand_model = MANO(SMPL_MODEL_DIR, batch_size=1, is_rhand=True)
    zero_orient = torch.zeros(1, 3)
    zero_hand_pose = torch.zeros(1, 15 * 3)
    output = hand_model(global_orient=zero_orient, right_hand_pose=zero_hand_pose)
    return output.rhand_vertices[0]
def get_flame_tpose():
    """Return the FLAME vertices for an all-zero global orientation."""
    face_model = FLAME(SMPL_MODEL_DIR, batch_size=1)
    output = face_model(global_orient=torch.zeros(1, 3))
    return output.flame_vertices[0]
def get_part_joints(smpl_joints):
    """Build per-part joints from a batch of joints.

    ``smpl_joints`` is indexed as ``[batch, joint, ...]``. The output holds,
    along dimension 1, first the midpoint (mean) of each joint pair listed
    below, then five individual joints copied unchanged.
    """
    joint_pairs = [
        (0, 1), (0, 2), (0, 3), (3, 6), (9, 12), (9, 13), (9, 14), (12, 15), (13, 16), (14, 17),
        (1, 4), (2, 5), (4, 7), (5, 8), (16, 18), (17, 19), (18, 20), (19, 21),
    ]
    single_joints = [10, 11, 15, 22, 23]

    midpoints = [
        torch.mean(smpl_joints[:, list(pair)], dim=1, keepdim=True) for pair in joint_pairs
    ]
    # Slicing (rather than integer indexing) keeps the joint dimension.
    singles = [smpl_joints[:, idx:idx + 1] for idx in single_joints]
    return torch.cat(midpoints + singles, dim=1)
# Segmentation labels (keys of '<body_model>_vert_segmentation.json') that seed
# each part's sub-mesh.
_SEGMENTATION_LABELS = {
    'face': ['head'],
    'arm': [
        'rightHand',
        'leftArm',
        'leftShoulder',
        'rightShoulder',
        'rightArm',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'forearm': [
        'rightHand',
        'leftHandIndex1',
        'rightHandIndex1',
        'leftForeArm',
        'rightForeArm',
        'leftHand',
    ],
    'larm': ['leftForeArm'],
    'rarm': ['rightForeArm'],
}


def _submesh_from_seed_vertices(faces, num_verts, seed_vids):
    """Extract the sub-mesh formed by every face that touches a seed vertex.

    Returns ``(vids, sub_faces)``: the sorted unique vertex ids used by the
    selected faces, and those faces re-indexed into ``vids``.
    """
    face_ids = np.nonzero(np.isin(faces, seed_vids).any(axis=1))[0]
    selected_faces = faces[face_ids]
    vids = np.unique(selected_faces).astype(np.longlong)
    # Lookup table from full-mesh vertex id to its position in ``vids``.
    remap = np.arange(num_verts)
    remap[vids] = np.arange(len(vids))
    return vids, remap[selected_faces].astype(np.longlong)


def _hand_submesh(part):
    """Map each MANO hand vertex to its nearest SMPL-topology vertex.

    Returns ``(vids, faces)`` where ``vids[i]`` is the SMPL-topology vertex
    closest to the i-th hand vertex of the SMPL-X template and ``faces`` are
    the MANO faces of the matching hand.
    """
    # NOTE: pickle.load can execute arbitrary code; both files must come from
    # a trusted model distribution.
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/MANO_SMPLX_vertex_ids.pkl'), 'rb'
    ) as pkl_file:
        smplx_mano_id = pickle.load(pkl_file)
    with open(
        os.path.join(SMPL_MODEL_DIR, 'model_transfer/smplx_to_smpl.pkl'), 'rb'
    ) as pkl_file:
        smplx_smpl_id = pickle.load(pkl_file)

    smplx_tpose = get_smplx_tpose().detach().cpu()
    # Explicit numpy -> torch round trip, so the subtraction below does not
    # depend on implicit ndarray/Tensor operator dispatch.
    # NOTE(review): assumes 'matrix' is a dense array, as SMPLX_ALL also does.
    smpl_tpose = torch.from_numpy(
        np.asarray(np.matmul(smplx_smpl_id['matrix'], smplx_tpose.numpy()))
    )

    hand_key = 'left_hand' if part == 'lhand' else 'right_hand'
    hand_vids = torch.from_numpy(np.asarray(smplx_mano_id[hand_key]).astype(np.int64))
    mano_vert = smplx_tpose[hand_vids]

    smpl2mano_id = []
    for vert in mano_vert:
        v_diff = smpl_tpose - vert
        sq_dist = torch.sum(v_diff * v_diff, dim=1)
        smpl2mano_id.append(int(torch.argmin(sq_dist)))

    vids = np.array(smpl2mano_id).astype(np.longlong)
    faces = get_mano_faces(hand_type='right' if part == 'rhand' else 'left').astype(np.longlong)
    return vids, faces


def _wrist_seed_vertices(body_model, part):
    """Return ids of the rest-pose vertices near the requested wrist joint.

    A vertex is selected when its *squared* distance to the wrist joint
    (joint 20 for 'lwrist', 21 for 'rwrist') is below 0.005.
    """
    if body_model == 'smplx':
        body_model_verts = get_smplx_tpose()
        tpose_joint = get_smplx_tpose_joint()
    elif body_model == 'smpl':
        body_model_verts = get_smpl_tpose()
        tpose_joint = get_smpl_tpose_joint()
    else:
        raise ValueError(
            "wrist parts are only supported for 'smpl' and 'smplx', got {!r}".format(body_model)
        )
    wrist_joint = tpose_joint[20] if part == 'lwrist' else tpose_joint[21]
    dist = 0.005
    sq_dist = torch.sum((body_model_verts - wrist_joint)**2, dim=1)
    return torch.nonzero(sq_dist < dist, as_tuple=True)[0].detach().cpu().numpy()


def get_partial_smpl(body_model='smpl', device=torch.device('cuda')):
    """Load, or compute and cache, vertex ids and faces of body-part sub-meshes.

    For each of 'lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm',
    'lwrist' and 'rwrist' the file
    ``<PARTIAL_MESH_DIR>/<body_model>_<part>_vids.npz`` is read when it
    exists. Otherwise the part is computed and written to that path:

    * hands: nearest SMPL-topology vertex for every MANO hand vertex of the
      SMPL-X template (this branch does not depend on ``body_model``);
    * face / arm parts: faces touching the vertices listed under the part's
      labels in ``<body_model>_vert_segmentation.json``;
    * wrists: faces touching vertices close to the wrist joint.

    Parameters
    ----------
    body_model: model name passed to ``get_model_faces`` / ``get_model_tpose``.
    device: unused; kept for backward compatibility.

    Returns
    -------
    dict mapping each part name to ``{'vids': ..., 'faces': ...}``.

    Raises
    ------
    ValueError
        If a wrist part must be computed for a model other than 'smpl' or
        'smplx'.
    """
    body_model_faces = get_model_faces(body_model)
    body_model_num_verts = len(get_model_tpose(body_model))

    part_vert_faces = {}
    for part in ['lhand', 'rhand', 'face', 'arm', 'forearm', 'larm', 'rarm', 'lwrist', 'rwrist']:
        part_vid_fname = '{}/{}_{}_vids.npz'.format(path_config.PARTIAL_MESH_DIR, body_model, part)
        if os.path.exists(part_vid_fname):
            # Read both arrays while the archive is open, then close it.
            with np.load(part_vid_fname) as part_vids:
                part_vert_faces[part] = {'vids': part_vids['vids'], 'faces': part_vids['faces']}
            continue

        if part in ['lhand', 'rhand']:
            vids, faces = _hand_submesh(part)
        elif part in _SEGMENTATION_LABELS:
            with open(
                os.path.join(SMPL_MODEL_DIR, '{}_vert_segmentation.json'.format(body_model)),
                'rb'
            ) as json_file:
                segmentation = json.load(json_file)
            seed_vids = []
            for label in _SEGMENTATION_LABELS[part]:
                seed_vids.extend(segmentation[label])
            vids, faces = _submesh_from_seed_vertices(
                body_model_faces, body_model_num_verts, seed_vids
            )
        else:
            # 'lwrist' / 'rwrist'
            seed_vids = _wrist_seed_vertices(body_model, part)
            vids, faces = _submesh_from_seed_vertices(
                body_model_faces, body_model_num_verts, seed_vids
            )

        np.savez(part_vid_fname, vids=vids, faces=faces)
        part_vert_faces[part] = {'vids': vids, 'faces': faces}

    return part_vert_faces