Spaces:
Configuration error
Configuration error
import torch | |
import torch.nn as nn | |
from .lbs import lbs, batch_rodrigues | |
import os.path as osp | |
import pickle | |
import numpy as np | |
def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
    """Convert ``array`` to a ``torch.Tensor`` of ``dtype`` on ``device``.

    Args:
        array: A ``torch.Tensor`` or anything ``torch.tensor`` accepts
            (NumPy array, nested list, scalar).
        dtype: Target dtype of the returned tensor.
        device: Target device of the returned tensor.

    Returns:
        torch.Tensor: For tensor input, the result of ``Tensor.to`` (no copy
        when dtype and device already match, and autograd history is kept).
        For any other input, a newly constructed tensor.
    """
    # The tensor class is spelled ``torch.Tensor``; the previous lowercase
    # string match ('torch.tensor') never fired, so tensors were always
    # re-wrapped with ``torch.tensor(...)``, which copies, detaches from the
    # autograd graph and emits a UserWarning.
    if isinstance(array, torch.Tensor):
        return array.to(device=device, dtype=dtype)
    return torch.tensor(array, dtype=dtype).to(device)
def to_np(array, dtype=np.float32):
    """Return ``array`` as a NumPy ndarray of the given ``dtype``.

    Objects whose type name mentions ``scipy.sparse`` are densified with
    their ``todense()`` method before the conversion.
    """
    is_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if is_sparse else array
    return np.array(dense, dtype=dtype)
class SMPLlayer(nn.Module):
    """SMPL body model wrapped as a ``torch.nn.Module``.

    The model parameters are read from a pickle file and registered as
    buffers. ``forward`` poses the model through ``lbs`` and then applies a
    global rotation, scale and translation to the result.
    """

    # Joint count used for the identity regressor rows and for slicing the
    # joint rows off the regressed output (poses are documented as (n, 72)).
    _NUM_JOINTS = 24

    def __init__(self,
                 model_path,
                 gender='neutral',
                 device=None,
                 regressor_path=None) -> None:
        """Load the SMPL parameters and register them as buffers.

        Args:
            model_path: Either a directory containing ``SMPL_<GENDER>.pkl``
                or the path of the model pickle itself.
            gender: Used (upper-cased) to build the file name when
                ``model_path`` is a directory.
            device: Device used when ``forward`` converts non-tensor inputs,
                and for the ``j_J_regressor`` buffer.
            regressor_path: Optional ``.npy`` file with an additional
                regressor matrix. When given, the ``j_*`` buffers used by
                ``forward(return_verts=False)`` are created.
        """
        super().__init__()
        dtype = torch.float32
        self.dtype = dtype
        self.device = device

        # Resolve the model file.
        if osp.isdir(model_path):
            model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
            smpl_path = osp.join(model_path, model_fn)
        else:
            smpl_path = model_path
        assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
            smpl_path)

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(smpl_path, 'rb') as smpl_file:
            data = pickle.load(smpl_file, encoding='latin1')

        self.faces = data['f']
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long))

        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
        # and transposed to 207 x 20670. The un-reshaped array is kept for
        # the optional regressor branch below.
        num_pose_basis = data['posedirs'].shape[-1]
        posedirs = data['posedirs']
        data['posedirs'] = np.reshape(data['posedirs'],
                                      [-1, num_pose_basis]).T

        for key in [
                'J_regressor', 'v_template', 'weights', 'posedirs',
                'shapedirs'
        ]:
            self.register_buffer(key, to_tensor(to_np(data[key]), dtype=dtype))

        # Indices of the parent of each joint; the root has no parent.
        parents = to_tensor(to_np(data['kintree_table'][0], dtype=np.int64),
                            dtype=torch.long)
        parents[0] = -1
        self.register_buffer('parents', parents)

        # Optional extra regressor: build a reduced model whose "vertices"
        # are the rows of [J_regressor; extra regressor] applied to the mesh.
        if regressor_path is not None:
            n_joints = self._NUM_JOINTS
            X_regressor = to_tensor(np.load(regressor_path))
            X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)

            # The first rows of the reduced model are the joints themselves,
            # so the joint regressor on the reduced model is an identity
            # block (ones on the main diagonal, zeros elsewhere).
            j_J_regressor = torch.eye(n_joints,
                                      X_regressor.shape[0],
                                      device=device)
            j_v_template = X_regressor @ self.v_template
            j_shapedirs = torch.einsum('vij,kv->kij', self.shapedirs,
                                       X_regressor)
            j_weights = X_regressor @ self.weights

            # Same reshape + transpose as the full pose basis above, applied
            # to the regressed basis.
            raw_posedirs = to_tensor(to_np(posedirs), dtype=dtype)
            j_posedirs = torch.einsum('ab,bde->ade', X_regressor,
                                      raw_posedirs)
            j_posedirs = j_posedirs.reshape(-1,
                                            num_pose_basis).t().contiguous()

            self.register_buffer('j_posedirs', j_posedirs)
            self.register_buffer('j_shapedirs', j_shapedirs)
            self.register_buffer('j_weights', j_weights)
            self.register_buffer('j_v_template', j_v_template)
            self.register_buffer('j_J_regressor', j_J_regressor)

    def forward(self,
                poses,
                shapes,
                Rh=None,
                Th=None,
                return_verts=True,
                return_tensor=True,
                scale=1,
                new_params=False,
                **kwargs):
        """ Forward pass for SMPL model

        Args:
            poses (n, 72)
            shapes (n, 10)
            Rh (n, 3): global orientation. Defaults to zeros.
            Th (n, 3): global translation. Defaults to zeros.
            return_verts (bool, optional): if True pose the full mesh
                (6890, 3); if False use the ``j_*`` buffers, which exist
                only when ``regressor_path`` was given. Defaults to True.
            return_tensor (bool, optional): if False the result is converted
                to a NumPy array. Defaults to True.
            scale: factor applied to the rotated points before translation.
            new_params: forwarded unchanged to ``lbs``.
            **kwargs: accepted and ignored.

        Returns:
            The transformed points of the FIRST batch element only
            (``vertices[0]``), as a tensor or a NumPy array.
        """
        if not isinstance(poses, torch.Tensor):
            dtype, device = self.dtype, self.device
            poses = to_tensor(poses, dtype, device)
            shapes = to_tensor(shapes, dtype, device)
            # Rh / Th are optional; converting None would raise, so they are
            # only converted when supplied.
            if Rh is not None:
                Rh = to_tensor(Rh, dtype, device)
            if Th is not None:
                Th = to_tensor(Th, dtype, device)

        bn = poses.shape[0]
        if Rh is None:
            Rh = torch.zeros(bn, 3, device=poses.device)
        if Th is None:
            Th = torch.zeros(bn, 3, device=poses.device)
        rot = batch_rodrigues(Rh)
        transl = Th.unsqueeze(dim=1)

        # Share one set of shape parameters across the whole batch.
        if shapes.shape[0] < bn:
            shapes = shapes.expand(bn, -1)

        if return_verts:
            vertices, joints = lbs(shapes,
                                   poses,
                                   self.v_template,
                                   self.shapedirs,
                                   self.posedirs,
                                   self.J_regressor,
                                   self.parents,
                                   self.weights,
                                   pose2rot=True,
                                   new_params=new_params,
                                   dtype=self.dtype)
        else:
            vertices, joints = lbs(shapes,
                                   poses,
                                   self.j_v_template,
                                   self.j_shapedirs,
                                   self.j_posedirs,
                                   self.j_J_regressor,
                                   self.parents,
                                   self.j_weights,
                                   pose2rot=True,
                                   new_params=new_params,
                                   dtype=self.dtype)
            # Drop the leading joint rows; only the rows that come from the
            # extra regressor are kept.
            vertices = vertices[:, self._NUM_JOINTS:, :]

        # Global rigid transform: rotate, scale, then translate.
        vertices = torch.matmul(vertices, rot.transpose(1, 2)) * scale + transl
        if not return_tensor:
            vertices = vertices.detach().cpu().numpy()
        return vertices[0]