Spaces:
Running
on
L40S
Running
on
L40S
from collections import namedtuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from .lbs import lbs, hybrik, rotmat_to_quat, quat_to_rotmat, rotation_matrix_to_angle_axis | |
try: | |
import cPickle as pk | |
except ImportError: | |
import pickle as pk | |
# Result container returned by the SMPL layer's forward and inverse passes.
ModelOutput = namedtuple(
    'ModelOutput',
    ['vertices', 'joints', 'joints_from_verts', 'rot_mats'],
)
# Make every field optional: any field left out at construction is None.
ModelOutput.__new__.__defaults__ = tuple(None for _ in ModelOutput._fields)
def to_tensor(array, dtype=torch.float32):
    '''Convert ``array`` into a ``torch.Tensor`` of the requested dtype.

    Parameters
    ----------
    array: array-like or torch.Tensor
        Data to convert. Anything accepted by ``torch.tensor`` works
        (numpy arrays, nested lists, scalars).
    dtype: torch.dtype, optional
        dtype of the returned tensor. (default=torch.float32)

    Returns
    -------
    torch.Tensor
        A tensor holding a copy of the data. When ``array`` is already a
        tensor the result is a detached copy cast to ``dtype``.
    '''
    # The previous string test on str(type(array)) looked for the
    # lower-case 'torch.tensor', which never occurs in the type name
    # 'torch.Tensor'. Tensors therefore always went through torch.tensor()
    # (copy-construct warning), and a matching check would have returned
    # None. Use the real type check and an explicit branch instead.
    if torch.is_tensor(array):
        # Same result as torch.tensor(tensor, dtype=dtype): detached copy.
        return array.clone().detach().to(dtype=dtype)
    return torch.tensor(array, dtype=dtype)
class Struct(object):
    '''Plain attribute bag: each keyword argument becomes an attribute.'''

    def __init__(self, **kwargs):
        # Copy the keyword arguments onto the instance one by one.
        for name in kwargs:
            setattr(self, name, kwargs[name])
def to_np(array, dtype=np.float32):
    '''Return ``array`` as a numpy array of the given dtype.

    Objects whose type name contains ``scipy.sparse`` are densified with
    their ``todense()`` method before the conversion.
    '''
    looks_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if looks_sparse else array
    return np.array(dense, dtype=dtype)
class SMPL_layer(nn.Module):
    ''' SMPL body model layer.

    Loads SMPL parameters from a pickle file, registers them as buffers and
    exposes two passes:

    * ``forward``: pose/shape parameters -> mesh and joints, via ``lbs``.
    * ``hybrik``: joint locations + twist angles -> mesh and joints, via the
      module-level ``hybrik`` function.
    '''
    NUM_JOINTS = 23
    NUM_BODY_JOINTS = 23
    NUM_BETAS = 10
    # The first NUM_JOINTS + 1 (= 24) names are the joints described by the
    # model file's kinematic tree; the last five are the extra leaf joints
    # appended in __init__ ("extend kinematic tree").
    JOINT_NAMES = [
        'pelvis',
        'left_hip',
        'right_hip',  # 2
        'spine1',
        'left_knee',
        'right_knee',  # 5
        'spine2',
        'left_ankle',
        'right_ankle',  # 8
        'spine3',
        'left_foot',
        'right_foot',  # 11
        'neck',
        'left_collar',
        'right_collar',  # 14
        'jaw',  # 15
        'left_shoulder',
        'right_shoulder',  # 17
        'left_elbow',
        'right_elbow',  # 19
        'left_wrist',
        'right_wrist',  # 21
        'left_thumb',
        'right_thumb',  # 23
        'head',
        'left_middle',
        'right_middle',  # 26
        'left_bigtoe',
        'right_bigtoe'  # 28
    ]
    LEAF_NAMES = [
        'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe'
    ]
    # Root joint index in the H36M-regressed joints and in the SMPL joints.
    root_idx_17 = 0
    root_idx_smpl = 0

    def __init__(self,
                 model_path,
                 h36m_jregressor,
                 gender='neutral',
                 dtype=torch.float32,
                 num_joints=29):
        ''' SMPL model layers

        Parameters:
        ----------
        model_path: str
            The path to the pickle file where the model parameters are
            stored (it is opened directly with ``open(model_path, 'rb')``).
        h36m_jregressor: array-like
            Regressor from vertices to Human3.6M joints; registered as the
            ``J_regressor_h36m`` buffer.
        gender: str, optional
            Stored as ``self.gender``. It is not used to select a model
            file in this class.
        dtype: torch.dtype, optional
            dtype of the floating point buffers. (default=torch.float32)
        num_joints: int, optional
            Number of joints in the kinematic tree. When it is smaller than
            ``len(JOINT_NAMES)`` (29) the parent table is truncated to the
            24 joints of the model file. (default=29)
        '''
        super().__init__()

        self.ROOT_IDX = self.JOINT_NAMES.index('pelvis')
        self.LEAF_IDX = [
            self.JOINT_NAMES.index(name) for name in self.LEAF_NAMES
        ]
        self.SPINE3_IDX = 9

        # NOTE(security): unpickling can execute arbitrary code. Only point
        # model_path at model files obtained from a trusted source.
        with open(model_path, 'rb') as smpl_file:
            self.smpl_data = Struct(**pk.load(smpl_file, encoding='latin1'))

        self.gender = gender
        self.dtype = dtype
        self.faces = self.smpl_data.f

        ''' Register Buffer '''
        # Faces
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.smpl_data.f, dtype=np.int64),
                      dtype=torch.long))

        # The vertices of the template model, (6890, 3)
        self.register_buffer(
            'v_template',
            to_tensor(to_np(self.smpl_data.v_template), dtype=dtype))

        # The shape components
        # Shape blend shapes basis, (6890, 3, 10)
        self.register_buffer(
            'shapedirs',
            to_tensor(to_np(self.smpl_data.shapedirs), dtype=dtype))

        # Pose blend shape basis: 6890 x 3 x 23*9, reshaped to 6890*3 x 23*9
        num_pose_basis = self.smpl_data.posedirs.shape[-1]
        # 23*9 x 6890*3
        posedirs = np.reshape(self.smpl_data.posedirs,
                              [-1, num_pose_basis]).T
        self.register_buffer('posedirs',
                             to_tensor(to_np(posedirs), dtype=dtype))

        # Vertices to Joints location (23 + 1, 6890)
        self.register_buffer(
            'J_regressor',
            to_tensor(to_np(self.smpl_data.J_regressor), dtype=dtype))
        # Vertices to Human3.6M Joints location (17, 6890)
        self.register_buffer('J_regressor_h36m',
                             to_tensor(to_np(h36m_jregressor), dtype=dtype))

        self.num_joints = num_joints

        # indices of parents for each joints
        num_model_joints = self.NUM_JOINTS + 1  # 24 joints in the model file
        parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long)
        parents[:num_model_joints] = to_tensor(
            to_np(self.smpl_data.kintree_table[0])).long()
        parents[0] = -1  # the root has no parent
        # extend kinematic tree with the five leaf joints
        parents[24] = 15
        parents[25] = 22
        parents[26] = 23
        parents[27] = 10
        parents[28] = 11
        if parents.shape[0] > self.num_joints:
            # Fewer joints requested than named: keep only the model's own.
            parents = parents[:num_model_joints]

        self.register_buffer('children_map',
                             self._parents_to_children(parents))
        # (24,)
        self.register_buffer('parents', parents)

        # (6890, 23 + 1)
        self.register_buffer(
            'lbs_weights',
            to_tensor(to_np(self.smpl_data.weights), dtype=dtype))

    def _parents_to_children(self, parents):
        ''' Build a joint -> child index table from a parent index table.

        Parameters
        ----------
        parents: torch.tensor (long), shape (J,)
            Parent index of every joint; negative for the root.

        Returns
        -------
        torch.tensor, same shape and dtype as ``parents``
            For every joint, the lowest-indexed joint that names it as
            parent, or -1 when it has none. Entries listed in
            ``self.LEAF_IDX`` are forced to -1, the root's child is set to
            3 and the child of ``self.SPINE3_IDX`` to the index of 'neck'.
        '''
        children = torch.full_like(parents, -1)
        # Never walk past the end of a truncated parent table.
        num_visited = min(self.num_joints, parents.shape[0])
        for joint in range(num_visited):
            parent = int(parents[joint])
            if parent < 0:
                # Root joint. Indexing with -1 would wrap around and mark
                # joint 0 as the child of the *last* joint.
                continue
            if children[parent] < 0:
                children[parent] = joint

        for leaf in self.LEAF_IDX:
            if leaf < children.shape[0]:
                children[leaf] = -1

        # Fixed overrides. (A former "-3" marker written to SPINE3_IDX was
        # immediately overwritten by the 'neck' index and has been removed.)
        children[0] = 3
        children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck')

        return children

    def forward(self,
                pose_axis_angle,
                betas,
                global_orient,
                transl=None,
                return_verts=True):
        ''' Forward pass for the SMPL model

        Parameters
        ----------
        pose_axis_angle: torch.tensor, shape Bx(J*3)
            It should be a tensor that contains joint rotations in
            axis-angle format.
        betas: torch.tensor, shape Bx10
            Shape parameters, e.g. predicted by some external model.
        global_orient: torch.tensor or None, shape Bx3
            Global orientation. When not None it is concatenated in front
            of ``pose_axis_angle``; pass None when ``pose_axis_angle``
            already contains the root rotation.
        transl: torch.tensor, optional, shape Bx3
            Global translation. When given it is added to all outputs;
            otherwise the outputs are root-centred.
        return_verts: bool, optional
            Currently unused; vertices are always returned. (default=True)

        Returns
        -------
        ModelOutput
            ``vertices``, ``joints``, ``rot_mats`` and ``joints_from_verts``
            (the H36M joints regressed from the vertices).
        '''
        # concate root orientation with thetas
        if global_orient is not None:
            full_pose = torch.cat([global_orient, pose_axis_angle], dim=1)
        else:
            full_pose = pose_axis_angle

        # The pose is axis-angle, so lbs must translate thetas to rotation
        # matrices itself (pose2rot=True).
        # vertices: (B, N, 3), joints: (B, K, 3)
        vertices, joints, rot_mats, joints_from_verts_h36m = lbs(
            betas,
            full_pose,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.J_regressor_h36m,
            self.parents,
            self.lbs_weights,
            pose2rot=True,
            dtype=self.dtype)

        if transl is not None:
            # apply translations
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset
            joints_from_verts_h36m += offset
        else:
            # No translation given: express everything relative to the root
            # joints. The roots are detached so that the centring itself
            # does not contribute gradients.
            h36m_root = joints_from_verts_h36m[
                :, self.root_idx_17, :].unsqueeze(1).detach()
            smpl_root = joints[:, self.root_idx_smpl, :].unsqueeze(1).detach()
            vertices = vertices - h36m_root
            joints = joints - smpl_root
            joints_from_verts_h36m = joints_from_verts_h36m - h36m_root

        output = ModelOutput(vertices=vertices,
                             joints=joints,
                             rot_mats=rot_mats,
                             joints_from_verts=joints_from_verts_h36m)
        return output

    def hybrik(self,
               pose_skeleton,
               betas,
               phis,
               global_orient,
               transl=None,
               return_verts=True,
               leaf_thetas=None):
        ''' Inverse pass for the SMPL model

        Parameters
        ----------
        pose_skeleton: torch.tensor, shape Bx(J*3)
            It should be a tensor that contains joint locations in
            (X, Y, Z) format.
        betas: torch.tensor, shape Bx10
            Shape parameters, e.g. predicted by some external model.
        phis: torch.tensor
            Passed unchanged to the module-level ``hybrik`` function
            (presumably the per-joint twist angles -- confirm in lbs).
        global_orient: torch.tensor, shape Bx3
            Global Orientations.
        transl: torch.tensor, optional, shape Bx3
            Global translation. When given it is added to the joints and
            vertices; otherwise they are root-centred.
        return_verts: bool, optional
            Currently unused; vertices are always returned. (default=True)
        leaf_thetas: torch.tensor, optional
            Rotations of the leaf joints as quaternions; reshaped to
            (B * len(LEAF_NAMES), 4) and converted to rotation matrices.

        Returns
        -------
        ModelOutput
            ``vertices``, ``joints``, ``rot_mats`` (reshaped to
            B x 24 x 3 x 3) and ``joints_from_verts``.
        '''
        batch_size = pose_skeleton.shape[0]

        if leaf_thetas is not None:
            # One quaternion per leaf joint (five of them).
            leaf_thetas = leaf_thetas.reshape(
                batch_size * len(self.LEAF_NAMES), 4)
            leaf_thetas = quat_to_rotmat(leaf_thetas)

        # Inside a method body the bare name ``hybrik`` resolves to the
        # module-level function, not to this method.
        vertices, new_joints, rot_mats, joints_from_verts = hybrik(
            betas,
            global_orient,
            pose_skeleton,
            phis,
            self.v_template,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.J_regressor_h36m,
            self.parents,
            self.children_map,
            self.lbs_weights,
            dtype=self.dtype,
            train=self.training,
            leaf_thetas=leaf_thetas)

        # One rotation matrix for each of the NUM_JOINTS + 1 (= 24) joints.
        rot_mats = rot_mats.reshape(batch_size, self.NUM_JOINTS + 1, 3, 3)

        # NOTE(review): unlike forward(), joints_from_verts is returned
        # without translation or root-centring (those lines were commented
        # out in the original) -- confirm this is intended.
        if transl is not None:
            offset = transl.unsqueeze(dim=1)
            new_joints += offset
            vertices += offset
        else:
            h36m_root = joints_from_verts[
                :, self.root_idx_17, :].unsqueeze(1).detach()
            smpl_root = new_joints[
                :, self.root_idx_smpl, :].unsqueeze(1).detach()
            vertices = vertices - h36m_root
            new_joints = new_joints - smpl_root

        output = ModelOutput(vertices=vertices,
                             joints=new_joints,
                             rot_mats=rot_mats,
                             joints_from_verts=joints_from_verts)
        return output