""" original from https://github.com/vchoutas/smplx modified by Vassilis and Yao """ import torch import torch.nn as nn import numpy as np import pickle from .lbs import ( Struct, to_tensor, to_np, lbs, vertices2landmarks, JointsFromVerticesSelector, find_dynamic_lmk_idx_and_bcoords, ) # SMPLX J14_NAMES = [ "right_ankle", "right_knee", "right_hip", "left_hip", "left_knee", "left_ankle", "right_wrist", "right_elbow", "right_shoulder", "left_shoulder", "left_elbow", "left_wrist", "neck", "head", ] SMPLX_names = [ "pelvis", "left_hip", "right_hip", "spine1", "left_knee", "right_knee", "spine2", "left_ankle", "right_ankle", "spine3", "left_foot", "right_foot", "neck", "left_collar", "right_collar", "head", "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist", "jaw", "left_eye_smplx", "right_eye_smplx", "left_index1", "left_index2", "left_index3", "left_middle1", "left_middle2", "left_middle3", "left_pinky1", "left_pinky2", "left_pinky3", "left_ring1", "left_ring2", "left_ring3", "left_thumb1", "left_thumb2", "left_thumb3", "right_index1", "right_index2", "right_index3", "right_middle1", "right_middle2", "right_middle3", "right_pinky1", "right_pinky2", "right_pinky3", "right_ring1", "right_ring2", "right_ring3", "right_thumb1", "right_thumb2", "right_thumb3", "right_eye_brow1", "right_eye_brow2", "right_eye_brow3", "right_eye_brow4", "right_eye_brow5", "left_eye_brow5", "left_eye_brow4", "left_eye_brow3", "left_eye_brow2", "left_eye_brow1", "nose1", "nose2", "nose3", "nose4", "right_nose_2", "right_nose_1", "nose_middle", "left_nose_1", "left_nose_2", "right_eye1", "right_eye2", "right_eye3", "right_eye4", "right_eye5", "right_eye6", "left_eye4", "left_eye3", "left_eye2", "left_eye1", "left_eye6", "left_eye5", "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top", "left_mouth_3", "left_mouth_2", "left_mouth_1", "left_mouth_5", "left_mouth_4", "mouth_bottom", "right_mouth_4", "right_mouth_5", "right_lip_1", "right_lip_2", "lip_top", "left_lip_2", "left_lip_1", "left_lip_3", "lip_bottom", "right_lip_3", "right_contour_1", "right_contour_2", "right_contour_3", "right_contour_4", "right_contour_5", "right_contour_6", "right_contour_7", "right_contour_8", "contour_middle", "left_contour_8", "left_contour_7", "left_contour_6", "left_contour_5", "left_contour_4", "left_contour_3", "left_contour_2", "left_contour_1", "head_top", "left_big_toe", "left_ear", "left_eye", "left_heel", "left_index", "left_middle", "left_pinky", "left_ring", "left_small_toe", "left_thumb", "nose", "right_big_toe", "right_ear", "right_eye", "right_heel", "right_index", "right_middle", "right_pinky", "right_ring", "right_small_toe", "right_thumb", ] extra_names = [ "head_top", "left_big_toe", "left_ear", "left_eye", "left_heel", "left_index", "left_middle", "left_pinky", "left_ring", "left_small_toe", "left_thumb", "nose", "right_big_toe", "right_ear", "right_eye", "right_heel", "right_index", "right_middle", "right_pinky", "right_ring", "right_small_toe", "right_thumb", ] SMPLX_names += extra_names part_indices = {} part_indices["body"] = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 123, 124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143, ]) part_indices["torso"] = np.array([ 0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 55, 56, 57, 58, 59, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, ]) part_indices["head"] = np.array([ 12, 15, 22, 23, 24, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 134, 136, 137, ]) part_indices["face"] = np.array([ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, ]) part_indices["upper"] = np.array([ 12, 13, 14, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, ]) part_indices["hand"] = np.array([ 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 128, 129, 130, 131, 133, 139, 140, 141, 142, 144, ]) part_indices["left_hand"] = np.array([ 20, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 128, 129, 130, 131, 133, ]) part_indices["right_hand"] = np.array([ 21, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 139, 140, 141, 142, 144, ]) # kinematic tree head_kin_chain = [15, 12, 9, 6, 3, 0] # --smplx joints # 00 - Global # 01 - L_Thigh # 02 - R_Thigh # 03 - Spine # 04 - L_Calf # 05 - R_Calf # 06 - Spine1 # 07 - L_Foot # 08 - R_Foot # 09 - Spine2 # 10 - L_Toes # 11 - R_Toes # 12 - Neck # 13 - L_Shoulder # 14 - R_Shoulder # 15 - Head # 16 - L_UpperArm # 17 - R_UpperArm # 18 - L_ForeArm # 19 - R_ForeArm # 20 - L_Hand # 21 - R_Hand # 22 - Jaw # 23 - L_Eye # 24 - R_Eye class SMPLX(nn.Module): """ Given smplx parameters, this class generates a differentiable SMPLX function which outputs a mesh and 3D joints """ def __init__(self, config): super(SMPLX, self).__init__() # print("creating the SMPLX Decoder") ss = np.load(config.smplx_model_path, allow_pickle=True) smplx_model = Struct(**ss) self.dtype = torch.float32 self.register_buffer( "faces_tensor", to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long), ) # The vertices of the template model self.register_buffer( "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)) # The shape components and expression # expression space is the same as FLAME shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype) shapedirs = torch.cat( [ shapedirs[:, :, :config.n_shape], shapedirs[:, :, 300:300 + config.n_exp], ], 2, ) self.register_buffer("shapedirs", shapedirs) # The pose components num_pose_basis = smplx_model.posedirs.shape[-1] posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype)) self.register_buffer( "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)) parents = to_tensor(to_np(smplx_model.kintree_table[0])).long() parents[0] = -1 self.register_buffer("parents", parents) self.register_buffer( "lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype)) # for face keypoints self.register_buffer( "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)) self.register_buffer( "lmk_bary_coords", torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype), ) self.register_buffer( "dynamic_lmk_faces_idx", torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long), ) self.register_buffer( "dynamic_lmk_bary_coords", torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype), ) # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long)) # -- initialize parameters # shape and expression self.register_buffer( "shape_params", nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False), ) self.register_buffer( "expression_params", nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False), ) # pose: represented as rotation matrx [number of joints, 3, 3] self.register_buffer( "global_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), requires_grad=False, ), ) self.register_buffer( "head_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), requires_grad=False, ), ) self.register_buffer( "neck_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), requires_grad=False, ), ) self.register_buffer( "jaw_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1, 1, 1), requires_grad=False, ), ) self.register_buffer( "eye_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(2, 1, 1), requires_grad=False, ), ) self.register_buffer( "body_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(21, 1, 1), requires_grad=False, ), ) self.register_buffer( "left_hand_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 1), requires_grad=False, ), ) self.register_buffer( "right_hand_pose", nn.Parameter( torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15, 1, 1), requires_grad=False, ), ) if config.extra_joint_path: self.extra_joint_selector = JointsFromVerticesSelector( fname=config.extra_joint_path) self.use_joint_regressor = True self.keypoint_names = SMPLX_names if self.use_joint_regressor: with open(config.j14_regressor_path, "rb") as f: j14_regressor = pickle.load(f, encoding="latin1") source = [] target = [] for idx, name in enumerate(self.keypoint_names): if name in J14_NAMES: source.append(idx) target.append(J14_NAMES.index(name)) source = np.asarray(source) target = np.asarray(target) self.register_buffer("source_idxs", torch.from_numpy(source)) self.register_buffer("target_idxs", torch.from_numpy(target)) joint_regressor = torch.from_numpy(j14_regressor).to( dtype=torch.float32) self.register_buffer("extra_joint_regressor", joint_regressor) self.part_indices = part_indices def forward( self, shape_params=None, expression_params=None, global_pose=None, body_pose=None, jaw_pose=None, eye_pose=None, left_hand_pose=None, right_hand_pose=None, ): """ Args: shape_params: [N, number of shape parameters] expression_params: [N, number of expression parameters] global_pose: pelvis pose, [N, 1, 3, 3] body_pose: [N, 21, 3, 3] jaw_pose: [N, 1, 3, 3] eye_pose: [N, 2, 3, 3] left_hand_pose: [N, 15, 3, 3] right_hand_pose: [N, 15, 3, 3] Returns: vertices: [N, number of vertices, 3] landmarks: [N, number of landmarks (68 face keypoints), 3] joints: [N, number of smplx joints (145), 3] """ if shape_params is None: batch_size = global_pose.shape[0] shape_params = self.shape_params.expand(batch_size, -1) else: batch_size = shape_params.shape[0] if expression_params is None: expression_params = self.expression_params.expand(batch_size, -1) if global_pose is None: global_pose = self.global_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) if body_pose is None: body_pose = self.body_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) if jaw_pose is None: jaw_pose = self.jaw_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) if eye_pose is None: eye_pose = self.eye_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) if left_hand_pose is None: left_hand_pose = self.left_hand_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) if right_hand_pose is None: right_hand_pose = self.right_hand_pose.unsqueeze(0).expand( batch_size, -1, -1, -1) shape_components = torch.cat([shape_params, expression_params], dim=1) full_pose = torch.cat( [ global_pose, body_pose, jaw_pose, eye_pose, left_hand_pose, right_hand_pose, ], dim=1, ) template_vertices = self.v_template.unsqueeze(0).expand( batch_size, -1, -1) # smplx vertices, joints = lbs( shape_components, full_pose, template_vertices, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, dtype=self.dtype, pose2rot=False, ) # face dynamic landmarks lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand( batch_size, -1) lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand( batch_size, -1, -1) dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.head_kin_chain, ) lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) final_joint_set = [joints, landmarks] if hasattr(self, "extra_joint_selector"): # Add any extra joints that might be needed extra_joints = self.extra_joint_selector(vertices, self.faces_tensor) final_joint_set.append(extra_joints) # Create the final joint set joints = torch.cat(final_joint_set, dim=1) # if self.use_joint_regressor: # reg_joints = torch.einsum("ji,bik->bjk", # self.extra_joint_regressor, vertices) # joints[:, self.source_idxs] = ( # joints[:, self.source_idxs].detach() * 0.0 + # reg_joints[:, self.target_idxs] * 1.0) return vertices, landmarks, joints def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"): """change absolute pose to relative pose Basic knowledge for SMPLX kinematic tree: absolute pose = parent pose * relative pose Here, pose must be represented as rotation matrix (batch_sizexnx3x3) """ if abs_joint == "head": # Pelvis -> Spine 1, 2, 3 -> Neck -> Head kin_chain = [15, 12, 9, 6, 3, 0] elif abs_joint == "neck": # Pelvis -> Spine 1, 2, 3 -> Neck -> Head kin_chain = [12, 9, 6, 3, 0] elif abs_joint == "right_wrist": # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder # -> right elbow -> right wrist kin_chain = [21, 19, 17, 14, 9, 6, 3, 0] elif abs_joint == "left_wrist": # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder # -> Left elbow -> Left wrist kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] else: raise NotImplementedError( f"pose_abs2rel does not support: {abs_joint}") batch_size = global_pose.shape[0] dtype = global_pose.dtype device = global_pose.device full_pose = torch.cat([global_pose, body_pose], dim=1) rel_rot_mat = (torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat( batch_size, 1, 1)) for idx in kin_chain[1:]: rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat) # This contains the absolute pose of the parent abs_parent_pose = rel_rot_mat.detach() # Let's assume that in the input this specific joint is predicted as an absolute value abs_joint_pose = body_pose[:, kin_chain[0] - 1] # abs_head = parents(abs_neck) * rel_head ==> rel_head = abs_neck.T * abs_head rel_joint_pose = torch.matmul( abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2), abs_joint_pose.reshape(-1, 3, 3), ) # Replace the new relative pose body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose return body_pose def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"): """change relative pose to absolute pose Basic knowledge for SMPLX kinematic tree: absolute pose = parent pose * relative pose Here, pose must be represented as rotation matrix (batch_sizexnx3x3) """ full_pose = torch.cat([global_pose, body_pose], dim=1) if abs_joint == "head": # Pelvis -> Spine 1, 2, 3 -> Neck -> Head kin_chain = [15, 12, 9, 6, 3, 0] elif abs_joint == "neck": # Pelvis -> Spine 1, 2, 3 -> Neck -> Head kin_chain = [12, 9, 6, 3, 0] elif abs_joint == "right_wrist": # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder # -> right elbow -> right wrist kin_chain = [21, 19, 17, 14, 9, 6, 3, 0] elif abs_joint == "left_wrist": # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder # -> Left elbow -> Left wrist kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] else: raise NotImplementedError( f"pose_rel2abs does not support: {abs_joint}") rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0) for idx in kin_chain: rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat) abs_pose = rel_rot_mat[:, None, :, :] return abs_pose