fffiloni's picture
Migrated from GitHub
2252f3d verified
from collections import namedtuple
import numpy as np
import torch
import torch.nn as nn
from .lbs import lbs, hybrik, rotmat_to_quat, quat_to_rotmat, rotation_matrix_to_angle_axis
try:
import cPickle as pk
except ImportError:
import pickle as pk
ModelOutput = namedtuple(
'ModelOutput', ['vertices', 'joints', 'joints_from_verts', 'rot_mats'])
ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields)
def to_tensor(array, dtype=torch.float32):
if 'torch.tensor' not in str(type(array)):
return torch.tensor(array, dtype=dtype)
class Struct(object):
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
def to_np(array, dtype=np.float32):
if 'scipy.sparse' in str(type(array)):
array = array.todense()
return np.array(array, dtype=dtype)
class SMPL_layer(nn.Module):
NUM_JOINTS = 23
NUM_BODY_JOINTS = 23
NUM_BETAS = 10
JOINT_NAMES = [
'pelvis',
'left_hip',
'right_hip', # 2
'spine1',
'left_knee',
'right_knee', # 5
'spine2',
'left_ankle',
'right_ankle', # 8
'spine3',
'left_foot',
'right_foot', # 11
'neck',
'left_collar',
'right_collar', # 14
'jaw', # 15
'left_shoulder',
'right_shoulder', # 17
'left_elbow',
'right_elbow', # 19
'left_wrist',
'right_wrist', # 21
'left_thumb',
'right_thumb', # 23
'head',
'left_middle',
'right_middle', # 26
'left_bigtoe',
'right_bigtoe' # 28
]
LEAF_NAMES = [
'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe'
]
root_idx_17 = 0
root_idx_smpl = 0
def __init__(self,
model_path,
h36m_jregressor,
gender='neutral',
dtype=torch.float32,
num_joints=29):
''' SMPL model layers
Parameters:
----------
model_path: str
The path to the folder or to the file where the model
parameters are stored
gender: str, optional
Which gender to load
'''
super(SMPL_layer, self).__init__()
self.ROOT_IDX = self.JOINT_NAMES.index('pelvis')
self.LEAF_IDX = [
self.JOINT_NAMES.index(name) for name in self.LEAF_NAMES
]
self.SPINE3_IDX = 9
with open(model_path, 'rb') as smpl_file:
self.smpl_data = Struct(**pk.load(smpl_file, encoding='latin1'))
self.gender = gender
self.dtype = dtype
self.faces = self.smpl_data.f
''' Register Buffer '''
# Faces
self.register_buffer(
'faces_tensor',
to_tensor(to_np(self.smpl_data.f, dtype=np.int64),
dtype=torch.long))
# The vertices of the template model, (6890, 3)
self.register_buffer(
'v_template',
to_tensor(to_np(self.smpl_data.v_template), dtype=dtype))
# The shape components
# Shape blend shapes basis, (6890, 3, 10)
self.register_buffer(
'shapedirs', to_tensor(to_np(self.smpl_data.shapedirs),
dtype=dtype))
# Pose blend shape basis: 6890 x 3 x 23*9, reshaped to 6890*3 x 23*9
num_pose_basis = self.smpl_data.posedirs.shape[-1]
# 23*9 x 6890*3
posedirs = np.reshape(self.smpl_data.posedirs, [-1, num_pose_basis]).T
self.register_buffer('posedirs', to_tensor(to_np(posedirs),
dtype=dtype))
# Vertices to Joints location (23 + 1, 6890)
self.register_buffer(
'J_regressor',
to_tensor(to_np(self.smpl_data.J_regressor), dtype=dtype))
# Vertices to Human3.6M Joints location (17, 6890)
self.register_buffer('J_regressor_h36m',
to_tensor(to_np(h36m_jregressor), dtype=dtype))
self.num_joints = num_joints
# indices of parents for each joints
parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long)
parents[:(self.NUM_JOINTS + 1)] = to_tensor(
to_np(self.smpl_data.kintree_table[0])).long()
parents[0] = -1
# extend kinematic tree
parents[24] = 15
parents[25] = 22
parents[26] = 23
parents[27] = 10
parents[28] = 11
if parents.shape[0] > self.num_joints:
parents = parents[:24]
self.register_buffer('children_map',
self._parents_to_children(parents))
# (24,)
self.register_buffer('parents', parents)
# (6890, 23 + 1)
self.register_buffer(
'lbs_weights', to_tensor(to_np(self.smpl_data.weights),
dtype=dtype))
def _parents_to_children(self, parents):
children = torch.ones_like(parents) * -1
for i in range(self.num_joints):
if children[parents[i]] < 0:
children[parents[i]] = i
for i in self.LEAF_IDX:
if i < children.shape[0]:
children[i] = -1
children[self.SPINE3_IDX] = -3
children[0] = 3
children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck')
return children
def forward(self,
pose_axis_angle,
betas,
global_orient,
transl=None,
return_verts=True):
''' Forward pass for the SMPL model
Parameters
----------
pose_axis_angle: torch.tensor, optional, shape Bx(J*3)
It should be a tensor that contains joint rotations in
axis-angle format. (default=None)
betas: torch.tensor, optional, shape Bx10
It can used if shape parameters
`betas` are predicted from some external model.
(default=None)
global_orient: torch.tensor, optional, shape Bx3
Global Orientations.
transl: torch.tensor, optional, shape Bx3
Global Translations.
return_verts: bool, optional
Return the vertices. (default=True)
Returns
-------
'''
# batch_size = pose_axis_angle.shape[0]
# concate root orientation with thetas
if global_orient is not None:
full_pose = torch.cat([global_orient, pose_axis_angle], dim=1)
else:
full_pose = pose_axis_angle
# Translate thetas to rotation matrics
pose2rot = True
# vertices: (B, N, 3), joints: (B, K, 3)
vertices, joints, rot_mats, joints_from_verts_h36m = lbs(
betas,
full_pose,
self.v_template,
self.shapedirs,
self.posedirs,
self.J_regressor,
self.J_regressor_h36m,
self.parents,
self.lbs_weights,
pose2rot=pose2rot,
dtype=self.dtype)
if transl is not None:
# apply translations
joints += transl.unsqueeze(dim=1)
vertices += transl.unsqueeze(dim=1)
joints_from_verts_h36m += transl.unsqueeze(dim=1)
else:
vertices = vertices - \
joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze(
1).detach()
joints = joints - \
joints[:, self.root_idx_smpl, :].unsqueeze(1).detach()
joints_from_verts_h36m = joints_from_verts_h36m - \
joints_from_verts_h36m[:, self.root_idx_17, :].unsqueeze(
1).detach()
output = ModelOutput(vertices=vertices,
joints=joints,
rot_mats=rot_mats,
joints_from_verts=joints_from_verts_h36m)
return output
def hybrik(self,
pose_skeleton,
betas,
phis,
global_orient,
transl=None,
return_verts=True,
leaf_thetas=None):
''' Inverse pass for the SMPL model
Parameters
----------
pose_skeleton: torch.tensor, optional, shape Bx(J*3)
It should be a tensor that contains joint locations in
(X, Y, Z) format. (default=None)
betas: torch.tensor, optional, shape Bx10
It can used if shape parameters
`betas` are predicted from some external model.
(default=None)
global_orient: torch.tensor, optional, shape Bx3
Global Orientations.
transl: torch.tensor, optional, shape Bx3
Global Translations.
return_verts: bool, optional
Return the vertices. (default=True)
Returns
-------
'''
batch_size = pose_skeleton.shape[0]
if leaf_thetas is not None:
leaf_thetas = leaf_thetas.reshape(batch_size * 5, 4)
leaf_thetas = quat_to_rotmat(leaf_thetas)
vertices, new_joints, rot_mats, joints_from_verts = hybrik(
betas,
global_orient,
pose_skeleton,
phis,
self.v_template,
self.shapedirs,
self.posedirs,
self.J_regressor,
self.J_regressor_h36m,
self.parents,
self.children_map,
self.lbs_weights,
dtype=self.dtype,
train=self.training,
leaf_thetas=leaf_thetas)
rot_mats = rot_mats.reshape(batch_size, 24, 3, 3)
# rot_aa = rotation_matrix_to_angle_axis(rot_mats)
# rot_mats = rotmat_to_quat(rot_mats).reshape(batch_size, 24 * 4)
if transl is not None:
new_joints += transl.unsqueeze(dim=1)
vertices += transl.unsqueeze(dim=1)
# joints_from_verts += transl.unsqueeze(dim=1)
else:
vertices = vertices - \
joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach()
new_joints = new_joints - \
new_joints[:, self.root_idx_smpl, :].unsqueeze(1).detach()
# joints_from_verts = joints_from_verts - joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach()
output = ModelOutput(vertices=vertices,
joints=new_joints,
rot_mats=rot_mats,
joints_from_verts=joints_from_verts)
return output