Spaces:
Configuration error
Configuration error
File size: 6,135 Bytes
1ba539f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
import torch.nn as nn
from .lbs import lbs, batch_rodrigues
import os.path as osp
import pickle
import numpy as np
def to_tensor(array, dtype=torch.float32, device=torch.device('cpu')):
    """Return ``array`` as a tensor of ``dtype`` placed on ``device``.

    Existing tensors are cast/moved with ``Tensor.to``. This is a no-op when
    they already match, and it keeps their autograd history. Anything else
    (numpy arrays, lists, scalars) is copied through ``torch.tensor``.

    Note: the previous substring test ``'torch.tensor' in str(type(x))``
    never matched, because the type string is ``<class 'torch.Tensor'>``.
    As a result, tensors were always re-copied and detached from the graph.
    """
    if isinstance(array, torch.Tensor):
        return array.to(device=device, dtype=dtype)
    return torch.tensor(array, dtype=dtype).to(device)
def to_np(array, dtype=np.float32):
    """Convert ``array`` into a dense numpy array of ``dtype``.

    scipy.sparse matrices are densified first; every other input is handed
    straight to ``np.array``, which always returns a fresh copy.
    """
    is_scipy_sparse = 'scipy.sparse' in str(type(array))
    dense = array.todense() if is_scipy_sparse else array
    return np.array(dense, dtype=dtype)
class SMPLlayer(nn.Module):
    """SMPL body model as a ``torch.nn.Module``.

    Loads an SMPL model pickle and registers its arrays as buffers. When
    ``regressor_path`` is given, an extra regressor is appended to
    ``J_regressor`` and a reduced "keypoints-only" model (the ``j_*``
    buffers) is precomputed. ``forward(return_verts=False)`` uses that
    model to return regressed keypoints instead of the full mesh.
    """

    def __init__(self,
                 model_path,
                 gender='neutral',
                 device=None,
                 regressor_path=None) -> None:
        """Load the SMPL model.

        Args:
            model_path: directory containing ``SMPL_<GENDER>.pkl``, or the
                path of the pickle file itself.
            gender: used only when ``model_path`` is a directory; it is
                upper-cased to build the file name.
            device: stored as ``self.device``. All buffers are created on the
                CPU; move the module with ``.to(...)`` as usual.
            regressor_path: optional file loaded with ``np.load``. Its rows
                are appended below ``J_regressor`` to build the ``j_*``
                buffers used by ``forward(return_verts=False)``.
        """
        super().__init__()
        dtype = torch.float32
        self.dtype = dtype
        self.device = device

        if osp.isdir(model_path):
            smpl_path = osp.join(model_path, f'SMPL_{gender.upper()}.pkl')
        else:
            smpl_path = model_path
        assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
            smpl_path)
        # NOTE: pickle.load can execute arbitrary code; only load model files
        # from trusted sources.
        with open(smpl_path, 'rb') as smpl_file:
            data = pickle.load(smpl_file, encoding='latin1')

        self.faces = data['f']
        self.register_buffer(
            'faces_tensor',
            to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long))

        # Keep the un-flattened posedirs for the extra-regressor branch. The
        # buffer version is flattened to (num_pose_basis, -1), as lbs expects.
        posedirs = data['posedirs']
        num_pose_basis = posedirs.shape[-1]
        data['posedirs'] = np.reshape(posedirs, [-1, num_pose_basis]).T
        for key in ('J_regressor', 'v_template', 'weights', 'posedirs',
                    'shapedirs'):
            self.register_buffer(key, to_tensor(to_np(data[key]), dtype=dtype))

        # Parent index of each joint in the kinematic tree; the root gets -1.
        parents = to_tensor(to_np(data['kintree_table'][0], dtype=np.int64),
                            dtype=torch.long)
        parents[0] = -1
        self.register_buffer('parents', parents)

        if regressor_path is not None:
            self._register_extra_regressor(regressor_path, posedirs,
                                           num_pose_basis)

    def _register_extra_regressor(self, regressor_path, posedirs,
                                  num_pose_basis):
        """Precompute the ``j_*`` buffers for the keypoints-only model.

        The rows of ``J_regressor`` are stacked above the extra regressor
        rows. Each stacked row then acts as a virtual "vertex" that is skinned
        like the mesh. ``forward(return_verts=False)`` drops the first 24 of
        them. This assumes ``J_regressor`` has exactly 24 rows (SMPL joints);
        confirm for other models.
        """
        X_regressor = to_tensor(np.load(regressor_path))
        X_regressor = torch.cat((self.J_regressor, X_regressor), dim=0)
        # Identity on the first 24 virtual vertices: their "joints" are the
        # J_regressor rows themselves. Created on the same device as the other
        # buffers (previously `device` was used here, which could mix devices).
        j_J_regressor = torch.eye(24, X_regressor.shape[0],
                                  dtype=X_regressor.dtype)
        j_v_template = X_regressor @ self.v_template
        j_shapedirs = torch.einsum('vij,kv->kij', self.shapedirs, X_regressor)
        j_weights = X_regressor @ self.weights
        j_posedirs = torch.einsum('ab,bde->ade', X_regressor,
                                  to_tensor(to_np(posedirs)))
        # Same flattening as the main posedirs buffer.
        j_posedirs = j_posedirs.reshape(-1, num_pose_basis).T.contiguous()
        self.register_buffer('j_posedirs', j_posedirs)
        self.register_buffer('j_shapedirs', j_shapedirs)
        self.register_buffer('j_weights', j_weights)
        self.register_buffer('j_v_template', j_v_template)
        self.register_buffer('j_J_regressor', j_J_regressor)

    def forward(self,
                poses,
                shapes,
                Rh=None,
                Th=None,
                return_verts=True,
                return_tensor=True,
                scale=1,
                new_params=False,
                **kwargs):
        """Forward pass for SMPL model.

        Args:
            poses (n, 72): pose parameters (tensor or array-like).
            shapes (n, 10): shape parameters. A batch smaller than ``n`` is
                expanded with ``expand(n, -1)``, so it must have one row.
            Rh (n, 3), optional: global orientation, converted with
                ``batch_rodrigues``. Defaults to zeros.
            Th (n, 3), optional: global translation. Defaults to zeros.
            return_verts (bool): if True (default), skin the full mesh.
                If False, skin the extra-regressor keypoints; this requires
                ``regressor_path`` at construction.
            return_tensor (bool): if False, return a detached numpy array.
            scale: multiplier applied after the global rotation and before
                the translation.
            new_params: passed through unchanged to ``lbs``.
            **kwargs: ignored.

        Returns:
            Posed vertices (or keypoints) of the FIRST batch element only.
        """
        # Array-like inputs are converted onto the device the buffers live
        # on, so lbs never sees mixed devices. Tensors are used as given.
        device = self.v_template.device

        def _as_tensor(x):
            if x is None or torch.is_tensor(x):
                return x
            return to_tensor(x, self.dtype, device)

        poses = _as_tensor(poses)
        shapes = _as_tensor(shapes)
        Rh = _as_tensor(Rh)
        Th = _as_tensor(Th)

        bn = poses.shape[0]
        if Rh is None:
            Rh = torch.zeros(bn, 3, dtype=self.dtype, device=poses.device)
        if Th is None:
            Th = torch.zeros(bn, 3, dtype=self.dtype, device=poses.device)
        rot = batch_rodrigues(Rh)
        transl = Th.unsqueeze(dim=1)
        if shapes.shape[0] < bn:
            shapes = shapes.expand(bn, -1)

        if return_verts:
            template, shapedirs, posedirs = (self.v_template, self.shapedirs,
                                             self.posedirs)
            J_regressor, weights = self.J_regressor, self.weights
        else:
            template, shapedirs, posedirs = (self.j_v_template,
                                             self.j_shapedirs,
                                             self.j_posedirs)
            J_regressor, weights = self.j_J_regressor, self.j_weights
        vertices, _ = lbs(shapes,
                          poses,
                          template,
                          shapedirs,
                          posedirs,
                          J_regressor,
                          self.parents,
                          weights,
                          pose2rot=True,
                          new_params=new_params,
                          dtype=self.dtype)
        if not return_verts:
            # Drop the 24 virtual vertices that came from J_regressor.
            vertices = vertices[:, 24:, :]

        vertices = torch.matmul(vertices, rot.transpose(1, 2)) * scale + transl
        if not return_tensor:
            vertices = vertices.detach().cpu().numpy()
        return vertices[0]
|