|
""" |
|
original from https://github.com/vchoutas/smplx |
|
modified by Vassilis and Yao |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import pickle |
|
|
|
from .lbs import ( |
|
Struct, |
|
to_tensor, |
|
to_np, |
|
lbs, |
|
vertices2landmarks, |
|
JointsFromVerticesSelector, |
|
find_dynamic_lmk_idx_and_bcoords, |
|
) |
|
|
|
|
|
# The 14-joint subset (legs, arms, neck/head) selected from the SMPL-X
# keypoints and matched against the J14 regressor in SMPLX.__init__.
J14_NAMES = [
    # legs: right leg bottom-up, then left leg top-down
    "right_ankle", "right_knee", "right_hip",
    "left_hip", "left_knee", "left_ankle",
    # arms: right arm wrist-to-shoulder, then left arm shoulder-to-wrist
    "right_wrist", "right_elbow", "right_shoulder",
    "left_shoulder", "left_elbow", "left_wrist",
    # head
    "neck", "head",
]
|
# Names of the keypoints returned by SMPLX.forward, in output order:
#   0-54    SMPL-X skeleton joints,
#   55-122  face landmarks,
#   123-144 extra keypoints (``extra_names``), appended once below.
# NOTE: the extra names used to be listed here AND appended again, which
# produced 167 entries (with duplicates) for a 145-keypoint output.
SMPLX_names = [
    # skeleton joints (0-54)
    "pelvis",
    "left_hip",
    "right_hip",
    "spine1",
    "left_knee",
    "right_knee",
    "spine2",
    "left_ankle",
    "right_ankle",
    "spine3",
    "left_foot",
    "right_foot",
    "neck",
    "left_collar",
    "right_collar",
    "head",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "jaw",
    "left_eye_smplx",
    "right_eye_smplx",
    "left_index1",
    "left_index2",
    "left_index3",
    "left_middle1",
    "left_middle2",
    "left_middle3",
    "left_pinky1",
    "left_pinky2",
    "left_pinky3",
    "left_ring1",
    "left_ring2",
    "left_ring3",
    "left_thumb1",
    "left_thumb2",
    "left_thumb3",
    "right_index1",
    "right_index2",
    "right_index3",
    "right_middle1",
    "right_middle2",
    "right_middle3",
    "right_pinky1",
    "right_pinky2",
    "right_pinky3",
    "right_ring1",
    "right_ring2",
    "right_ring3",
    "right_thumb1",
    "right_thumb2",
    "right_thumb3",
    # face landmarks (55-122)
    "right_eye_brow1",
    "right_eye_brow2",
    "right_eye_brow3",
    "right_eye_brow4",
    "right_eye_brow5",
    "left_eye_brow5",
    "left_eye_brow4",
    "left_eye_brow3",
    "left_eye_brow2",
    "left_eye_brow1",
    "nose1",
    "nose2",
    "nose3",
    "nose4",
    "right_nose_2",
    "right_nose_1",
    "nose_middle",
    "left_nose_1",
    "left_nose_2",
    "right_eye1",
    "right_eye2",
    "right_eye3",
    "right_eye4",
    "right_eye5",
    "right_eye6",
    "left_eye4",
    "left_eye3",
    "left_eye2",
    "left_eye1",
    "left_eye6",
    "left_eye5",
    "right_mouth_1",
    "right_mouth_2",
    "right_mouth_3",
    "mouth_top",
    "left_mouth_3",
    "left_mouth_2",
    "left_mouth_1",
    "left_mouth_5",
    "left_mouth_4",
    "mouth_bottom",
    "right_mouth_4",
    "right_mouth_5",
    "right_lip_1",
    "right_lip_2",
    "lip_top",
    "left_lip_2",
    "left_lip_1",
    "left_lip_3",
    "lip_bottom",
    "right_lip_3",
    "right_contour_1",
    "right_contour_2",
    "right_contour_3",
    "right_contour_4",
    "right_contour_5",
    "right_contour_6",
    "right_contour_7",
    "right_contour_8",
    "contour_middle",
    "left_contour_8",
    "left_contour_7",
    "left_contour_6",
    "left_contour_5",
    "left_contour_4",
    "left_contour_3",
    "left_contour_2",
    "left_contour_1",
]

# Extra keypoints (123-144), appended exactly once.
extra_names = [
    "head_top",
    "left_big_toe",
    "left_ear",
    "left_eye",
    "left_heel",
    "left_index",
    "left_middle",
    "left_pinky",
    "left_ring",
    "left_small_toe",
    "left_thumb",
    "nose",
    "right_big_toe",
    "right_ear",
    "right_eye",
    "right_heel",
    "right_index",
    "right_middle",
    "right_pinky",
    "right_ring",
    "right_small_toe",
    "right_thumb",
]

SMPLX_names += extra_names
|
|
|
# Keypoint index groups per body region. Values go up to 144 and presumably
# index the keypoints named by SMPLX_names / returned by SMPLX.forward
# (TODO confirm against consumers). Contiguous runs are written as ranges.
part_indices = {
    "body": np.array([*range(0, 25), *range(123, 128), 132,
                      *range(134, 139), 143]),
    "torso": np.array([0, 1, 2, 3, 6, 9, *range(12, 20), 22, 23, 24,
                       *range(55, 60), *range(76, 145)]),
    "head": np.array([12, 15, 22, 23, 24, *range(55, 124), 125, 126, 134,
                      136, 137]),
    "face": np.array([*range(55, 123)]),
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    "hand": np.array([20, 21, *range(25, 55), *range(128, 132), 133,
                      *range(139, 143), 144]),
    "left_hand": np.array([20, *range(25, 40), *range(128, 132), 133]),
    "right_hand": np.array([21, *range(40, 55), *range(139, 143), 144]),
}

# Joint indices (SMPLX_names order) from the head back to the pelvis:
# head(15) -> neck(12) -> spine3(9) -> spine2(6) -> spine1(3) -> pelvis(0).
head_kin_chain = [15, 12, 9, 6, 3, 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Kinematic chains used by SMPLX.pose_abs2rel and SMPLX.pose_rel2abs. Each
# entry lists joint indices into the full pose (global pose at index 0,
# followed by the 21 body joints), ordered from the requested joint back to
# the pelvis (index 0).
_ABS_JOINT_KIN_CHAINS = {
    "head": (15, 12, 9, 6, 3, 0),
    "neck": (12, 9, 6, 3, 0),
    "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
    "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
}


def _get_kin_chain(abs_joint, func_name):
    """Return the kinematic chain registered for ``abs_joint``.

    Args:
        abs_joint: joint name; one of the keys of ``_ABS_JOINT_KIN_CHAINS``.
        func_name: name of the public caller, used in the error message.

    Raises:
        NotImplementedError: if ``abs_joint`` is not supported.
    """
    try:
        return _ABS_JOINT_KIN_CHAINS[abs_joint]
    except (KeyError, TypeError):  # TypeError: unhashable abs_joint
        raise NotImplementedError(
            f"{func_name} does not support: {abs_joint}") from None


class SMPLX(nn.Module):
    """
    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh and 3D joints.

    Poses are passed as rotation matrices (``lbs`` is called with
    ``pose2rot=False``). Any pose/shape argument left as None in ``forward``
    falls back to a registered default buffer (zeros / identity rotations).
    """

    def __init__(self, config):
        """Load the SMPL-X model file and register its data as buffers.

        Args:
            config: object providing ``smplx_model_path``, ``n_shape``,
                ``n_exp``, ``extra_joint_path`` and ``j14_regressor_path``.
        """
        super().__init__()

        model_data = np.load(config.smplx_model_path, allow_pickle=True)
        try:
            # Unpacking with ** reads every member, so the archive can be
            # closed right after.
            smplx_model = Struct(**model_data)
        finally:
            # For .npz archives np.load returns an NpzFile that keeps the
            # file open until closed; other return types have no close().
            if hasattr(model_data, "close"):
                model_data.close()

        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        self.register_buffer(
            "v_template",
            to_tensor(to_np(smplx_model.v_template), dtype=self.dtype))

        # Keep the first n_shape identity components plus n_exp expression
        # components taken from column 300 onwards (assumes the model file
        # stores 300 shape components first -- TODO confirm).
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)

        # Pose blend shapes flattened to [num_pose_basis, -1].
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs",
                             to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor",
            to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype))
        # Parent index of every joint; the root entry is set to -1.
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer(
            "lbs_weights",
            to_tensor(to_np(smplx_model.weights), dtype=self.dtype))

        # Static and pose-dependent (dynamic) face landmark embeddings.
        self.register_buffer(
            "lmk_faces_idx",
            torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long))
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords,
                         dtype=self.dtype),
        )
        self.register_buffer("head_kin_chain",
                             torch.tensor(head_kin_chain, dtype=torch.long))

        # Default parameters used by forward() for arguments left as None.
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype),
                         requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype),
                         requires_grad=False),
        )

        def _identity_rotations(count):
            # `count` stacked 3x3 identity matrices: [count, 3, 3].
            return nn.Parameter(
                torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(count, 1, 1),
                requires_grad=False,
            )

        # Registered in the same order as before so the state_dict layout is
        # unchanged. head_pose / neck_pose are not read by forward().
        for buffer_name, num_joints in (
            ("global_pose", 1),
            ("head_pose", 1),
            ("neck_pose", 1),
            ("jaw_pose", 1),
            ("eye_pose", 2),
            ("body_pose", 21),
            ("left_hand_pose", 15),
            ("right_hand_pose", 15),
        ):
            self.register_buffer(buffer_name, _identity_rotations(num_joints))

        # Optional extra keypoints selected from mesh vertices.
        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(
                fname=config.extra_joint_path)
        self.use_joint_regressor = True  # always True here
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            # NOTE: pickle.load can execute arbitrary code; only point
            # j14_regressor_path at trusted files.
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            # Map keypoint indices (source) to their position in J14_NAMES
            # (target).
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in J14_NAMES:
                    source.append(idx)
                    target.append(J14_NAMES.index(name))
            source = np.asarray(source)
            target = np.asarray(target)
            self.register_buffer("source_idxs", torch.from_numpy(source))
            self.register_buffer("target_idxs", torch.from_numpy(target))
            joint_regressor = torch.from_numpy(j14_regressor).to(
                dtype=torch.float32)
            self.register_buffer("extra_joint_regressor", joint_regressor)
        self.part_indices = part_indices

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return the leading dimension of the first non-None tensor, or 1."""
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """
        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
            Any argument left as None uses the registered default buffer,
            expanded to the batch size.
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3]
        """
        # Batch size comes from the first given input (shape_params first,
        # then global_pose); with no inputs at all, evaluate one default
        # sample instead of failing on ``None.shape``.
        batch_size = self._infer_batch_size(
            shape_params, global_pose, expression_params, body_pose,
            jaw_pose, eye_pose, left_hand_pose, right_hand_pose)
        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        # Joint order: global, body (21), jaw, eyes (2), left hand (15),
        # right hand (15).
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(
            batch_size, -1, -1)

        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )

        # Static landmarks followed by the pose-dependent contour landmarks.
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(
            batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Output keypoints: skeleton joints, landmarks, then (optionally) the
        # vertex-selected extra joints.
        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            extra_joints = self.extra_joint_selector(vertices,
                                                     self.faces_tensor)
            final_joint_set.append(extra_joints)
        joints = torch.cat(final_joint_set, dim=1)

        return vertices, landmarks, joints

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """change absolute pose to relative pose

        Basic knowledge for SMPLX kinematic tree:
            absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Args:
            global_pose: [B, 1, 3, 3] pelvis rotation.
            body_pose: [B, 21, 3, 3]; the entry for ``abs_joint`` holds its
                absolute rotation.
            abs_joint: "head", "neck", "right_wrist" or "left_wrist".

        Returns:
            ``body_pose`` itself, modified IN PLACE: the ``abs_joint`` entry is
            replaced by its rotation relative to its parent. The parent's
            accumulated rotation is detached, so no gradient flows through it.

        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        kin_chain = _get_kin_chain(abs_joint, "pose_abs2rel")

        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        # Accumulate the parent's absolute rotation, root-most factor on the
        # left: R_0 @ ... @ R_parent.
        rel_rot_mat = (torch.eye(3, device=device,
                                 dtype=dtype).unsqueeze_(dim=0).repeat(
                                     batch_size, 1, 1))
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)

        abs_parent_pose = rel_rot_mat.detach()
        # body_pose excludes the global pose, hence the -1 offset.
        abs_joint_pose = body_pose[:, kin_chain[0] - 1]
        # relative = parent^T @ absolute (inverse of a rotation = transpose).
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """change relative pose to absolute pose

        Basic knowledge for SMPLX kinematic tree:
            absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Args:
            global_pose: [B, 1, 3, 3] pelvis rotation.
            body_pose: [B, 21, 3, 3] relative joint rotations.
            abs_joint: "head", "neck", "right_wrist" or "left_wrist".

        Returns:
            [B, 1, 3, 3] absolute rotation of ``abs_joint``.

        Raises:
            NotImplementedError: for an unsupported ``abs_joint``.
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = _get_kin_chain(abs_joint, "pose_rel2abs")

        # Product R_0 @ ... @ R_joint along the chain; the [1, 3, 3] identity
        # broadcasts against the batch on the first multiplication.
        rel_rot_mat = torch.eye(3,
                                device=full_pose.device,
                                dtype=full_pose.dtype).unsqueeze_(dim=0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
        return abs_pose
|
|