mart9992's picture
m
2cd560a
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from ..builder import MESH_MODELS
# ``smplx`` is an optional dependency: record whether it is importable so
# that ``SMPL.__init__`` can fail with a clear message instead of at import.
try:
    from smplx import SMPL as SMPL_
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    has_smpl = False
else:
    has_smpl = True
@MESH_MODELS.register_module()
class SMPL(nn.Module):
    """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
    multi-person linear model''. This module is based on the smplx project
    (https://github.com/vchoutas/smplx).

    Args:
        smpl_path (str): The path to the folder where the model weights are
            stored.
        joints_regressor (str): The path to the file where the joints
            regressor weights are stored (loaded with ``np.load``).
    """

    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        # Pose, orientation and translation are always supplied by the
        # caller, so the smplx models are built without those parameters.
        self.smpl_neutral = SMPL_(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='neutral')

        self.smpl_male = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='male')

        self.smpl_female = SMPL_(
            model_path=smpl_path,
            create_betas=False,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False,
            gender='female')

        # The regressor file is expected to hold a [K, V] matrix; a leading
        # axis is added ([1, K, V]) so it broadcasts over the batch in
        # ``torch.matmul`` against [B, V, 3] vertices.
        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]

    def smpl_forward(self, model, **kwargs):
        """Apply a specific SMPL model with given model parameters.

        Note:
            B: batch size
            V: number of vertices
            K: number of joints

        Args:
            model: The smplx SMPL model to run.
            **kwargs: Keyword arguments forwarded to ``model``. ``betas`` is
                required; its batch size, device and dtype determine the
                outputs.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed
                    from mesh vertices.
        """
        betas = kwargs['betas']
        batch_size = betas.shape[0]
        device = betas.device
        output = {}
        if batch_size == 0:
            # Short-circuit empty batches with correctly shaped empty tensors
            # instead of calling the smplx model on zero samples.
            output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
            output['joints'] = betas.new_zeros([0, self.num_joints, 3])
        else:
            smpl_out = model(**kwargs)
            output['vertices'] = smpl_out.vertices
            # [1, K, V] @ [B, V, 3] -> [B, K, 3]
            output['joints'] = torch.matmul(
                self.joints_regressor.to(device), output['vertices'])
        return output

    def get_faces(self):
        """Return mesh faces.

        Note:
            F: number of faces

        Returns:
            faces: np.ndarray([F, 3]), mesh faces
        """
        return self.smpl_neutral.faces

    def forward(self,
                betas,
                body_pose,
                global_orient,
                transl=None,
                gender=None):
        """Forward function.

        Note:
            B: batch size
            J: number of controllable joints of model, for smpl model J=23
            K: number of joints

        Args:
            betas: Tensor([B, 10]), human body shape parameters of SMPL model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3)].
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3)].
            transl: Tensor([B, 3]), global translation of human body.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male , 1 for female. Samples whose value is
                none of these (negative values count as neutral) get
                all-zero vertices and joints.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed from
                    mesh vertices.
        """
        batch_size = betas.shape[0]
        # 2-D body_pose means flat axis-angle vectors, which smplx must
        # convert to rotation matrices; otherwise matrices are used as-is.
        pose2rot = body_pose.dim() == 2

        # Without per-sample genders (or with an empty batch) the whole
        # batch goes through the neutral model.
        if gender is None or batch_size == 0:
            return self.smpl_forward(
                self.smpl_neutral,
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl,
                pose2rot=pose2rot)

        output = {
            'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
            'joints': betas.new_zeros([batch_size, self.num_joints, 3])
        }

        # Route each sample to its gender-specific model and scatter the
        # results back into the full-batch output.
        gender_models = (
            (gender < 0, self.smpl_neutral),
            (gender == 0, self.smpl_male),
            (gender == 1, self.smpl_female),
        )
        for mask, model in gender_models:
            _out = self.smpl_forward(
                model,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

        return output