from collections import namedtuple
import os

import numpy as np
import torch
import torch.nn as nn
import yaml
from torch.nn import functional as F

from .layers.Resnet import ResNet
from .layers.smpl.SMPL import SMPL_layer

ModelOutput = namedtuple(
    typename='ModelOutput',
    field_names=[
        'pred_shape', 'pred_theta_mats', 'pred_phi', 'pred_delta_shape',
        'pred_leaf', 'pred_uvd_jts', 'pred_xyz_jts_29', 'pred_xyz_jts_24',
        'pred_xyz_jts_24_struct', 'pred_xyz_jts_17', 'pred_vertices',
        'maxvals', 'cam_scale', 'cam_trans', 'cam_root', 'uvd_heatmap',
        'transl', 'img_feat', 'pred_camera', 'pred_aa'
    ])
ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields)


def update_config(config_file):
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        return config


def norm_heatmap(norm_type, heatmap):
    # Input tensor shape: [N, C, ...]
    shape = heatmap.shape
    if norm_type == 'softmax':
        heatmap = heatmap.reshape(*shape[:2], -1)
        # global softmax over all spatial locations
        heatmap = F.softmax(heatmap, 2)
        return heatmap.reshape(*shape)
    else:
        raise NotImplementedError
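
# A minimal sketch (not part of the original pipeline) of how `norm_heatmap`
# feeds the soft-argmax decoding performed in `HybrIKBaseSMPLCam.forward`
# below: the raw heatmap is normalised with a global softmax, marginalised
# over two of the three volume axes, and each coordinate is recovered as the
# expectation over voxel indices. The function name and shapes are
# illustrative only.
def _soft_argmax_sketch(raw_heatmap, depth_dim, height_dim, width_dim):
    n, c = raw_heatmap.shape[:2]
    hm = norm_heatmap('softmax', raw_heatmap.reshape(n, c, -1))
    hm = hm.reshape(n, c, depth_dim, height_dim, width_dim)
    hm_x = hm.sum((2, 3))  # marginalise depth/height -> x distribution
    idx = torch.arange(width_dim, dtype=torch.float32, device=hm.device)
    coord_x = (hm_x * idx).sum(dim=2)  # expected x-index per joint
    return coord_x / float(width_dim) - 0.5  # same -0.5 ~ 0.5 convention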

class HybrIKBaseSMPLCam(nn.Module):

    def __init__(self,
                 cfg_file,
                 smpl_path,
                 data_path,
                 norm_layer=nn.BatchNorm2d):
        super(HybrIKBaseSMPLCam, self).__init__()
        cfg = update_config(cfg_file)['MODEL']

        self.deconv_dim = cfg['NUM_DECONV_FILTERS']
        self._norm_layer = norm_layer
        self.num_joints = cfg['NUM_JOINTS']
        self.norm_type = cfg['POST']['NORM_TYPE']
        self.depth_dim = cfg['EXTRA']['DEPTH_DIM']
        self.height_dim = cfg['HEATMAP_SIZE'][0]
        self.width_dim = cfg['HEATMAP_SIZE'][1]
        self.smpl_dtype = torch.float32

        backbone = ResNet
        self.preact = backbone(f"resnet{cfg['NUM_LAYERS']}")

        # Load an ImageNet-pretrained torchvision model of matching depth
        import torchvision.models as tm
        if cfg['NUM_LAYERS'] == 101:
            x = tm.resnet101(pretrained=True)
            self.feature_channel = 2048
        elif cfg['NUM_LAYERS'] == 50:
            x = tm.resnet50(pretrained=True)
            self.feature_channel = 2048
        elif cfg['NUM_LAYERS'] == 34:
            x = tm.resnet34(pretrained=True)
            self.feature_channel = 512
        elif cfg['NUM_LAYERS'] == 18:
            x = tm.resnet18(pretrained=True)
            self.feature_channel = 512
        else:
            raise NotImplementedError

        # Copy only the pretrained weights whose names and shapes match
        model_state = self.preact.state_dict()
        state = {
            k: v
            for k, v in x.state_dict().items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(state)
        self.preact.load_state_dict(model_state)

        self.deconv_layers = self._make_deconv_layer()
        self.final_layer = nn.Conv2d(self.deconv_dim[2],
                                     self.num_joints * self.depth_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        h36m_jregressor = np.load(
            os.path.join(data_path, 'J_regressor_h36m.npy'))
        self.smpl = SMPL_layer(smpl_path,
                               h36m_jregressor=h36m_jregressor,
                               dtype=self.smpl_dtype)

        # Left/right joint index pairs, used for horizontal flipping
        self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23))
        self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23),
                               (25, 26), (27, 28))
        self.leaf_pairs = ((0, 1), (3, 4))
        self.root_idx_smpl = 0

        # Mean SMPL shape, used as the initial shape estimate
        init_shape = np.load(os.path.join(data_path, 'h36m_mean_beta.npy'))
        self.register_buffer('init_shape', torch.Tensor(init_shape).float())

        init_cam = torch.tensor([0.9, 0, 0])
        self.register_buffer('init_cam', torch.Tensor(init_cam).float())

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(self.feature_channel, 1024)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(p=0.5)
        self.decshape = nn.Linear(1024, 10)
        self.decphi = nn.Linear(1024, 23 * 2)  # [cos(phi), sin(phi)]
        self.deccam = nn.Linear(1024, 3)

        self.focal_length = cfg['FOCAL_LENGTH']
        self.input_size = 256.0

    def _make_deconv_layer(self):
        deconv_layers = []
        deconv1 = nn.ConvTranspose2d(self.feature_channel,
                                     self.deconv_dim[0],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn1 = self._norm_layer(self.deconv_dim[0])
        deconv2 = nn.ConvTranspose2d(self.deconv_dim[0],
                                     self.deconv_dim[1],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn2 = self._norm_layer(self.deconv_dim[1])
        deconv3 = nn.ConvTranspose2d(self.deconv_dim[1],
                                     self.deconv_dim[2],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn3 = self._norm_layer(self.deconv_dim[2])

        deconv_layers.extend([deconv1, bn1, nn.ReLU(inplace=True)])
        deconv_layers.extend([deconv2, bn2, nn.ReLU(inplace=True)])
        deconv_layers.extend([deconv3, bn3, nn.ReLU(inplace=True)])

        return nn.Sequential(*deconv_layers)

    def _initialize(self):
        for name, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

    def flip_uvd_coord(self, pred_jts, shift=False, flatten=True):
        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3
            num_batches = pred_jts.shape[0]

        # flip the u (horizontal) axis
        if shift:
            pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        else:
            pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]

        # swap left/right joint channels
        for pair in self.joint_pairs_29:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0, dim1)).long()
            inv_idx = torch.Tensor((dim1, dim0)).long()
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)

        return pred_jts

    def flip_xyz_coord(self, pred_jts, flatten=True):
        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3
            num_batches = pred_jts.shape[0]

        pred_jts[:, :, 0] = -pred_jts[:, :, 0]

        for pair in self.joint_pairs_29:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0, dim1)).long()
            inv_idx = torch.Tensor((dim1, dim0)).long()
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)

        return pred_jts

    def flip_phi(self, pred_phi):
        # negate sin(phi); phi indices are offset by 1 (no root twist)
        pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]

        for pair in self.joint_pairs_24:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0 - 1, dim1 - 1)).long()
            inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long()
            pred_phi[:, idx] = pred_phi[:, inv_idx]

        return pred_phi
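
    # Reader's note on the camera model used in `forward` below (derived
    # from the code, not original documentation): `deccam` predicts a
    # weak-perspective scale `s` and an in-plane translation `(tx, ty)`.
    # The scale is converted into an absolute root depth d = f / (input_size
    # * s), and each (u, v) image coordinate is back-projected into metric
    # camera space as x = (u * input_size / f) * (z + d) - tx (likewise for
    # y), where z is the per-joint depth offset in metres (the network's
    # normalised depth times 2.2).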

    def forward(self,
                x,
                flip_item=None,
                flip_output=False,
                gt_uvd=None,
                gt_uvd_weight=None,
                **kwargs):
        batch_size = x.shape[0]

        x0 = self.preact(x)
        out = self.deconv_layers(x0)
        out = self.final_layer(out)

        out = out.reshape((out.shape[0], self.num_joints, -1))
        maxvals, _ = torch.max(out, dim=2, keepdim=True)

        out = norm_heatmap(self.norm_type, out)
        assert out.dim() == 3, out.shape

        heatmaps = out / out.sum(dim=2, keepdim=True)
        heatmaps = heatmaps.reshape(
            (heatmaps.shape[0], self.num_joints, self.depth_dim,
             self.height_dim, self.width_dim))

        # Marginal distributions along each volume axis
        hm_x0 = heatmaps.sum((2, 3))
        hm_y0 = heatmaps.sum((2, 4))
        hm_z0 = heatmaps.sum((3, 4))

        range_tensor = torch.arange(hm_x0.shape[-1],
                                    dtype=torch.float32,
                                    device=hm_x0.device)
        # Soft-argmax: expectation of the voxel index under each marginal
        hm_x = hm_x0 * range_tensor
        hm_y = hm_y0 * range_tensor
        hm_z = hm_z0 * range_tensor

        coord_x = hm_x.sum(dim=2, keepdim=True)
        coord_y = hm_y.sum(dim=2, keepdim=True)
        coord_z = hm_z.sum(dim=2, keepdim=True)

        coord_x = coord_x / float(self.width_dim) - 0.5
        coord_y = coord_y / float(self.height_dim) - 0.5
        coord_z = coord_z / float(self.depth_dim) - 0.5

        # -0.5 ~ 0.5
        pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)

        x0 = self.avg_pool(x0)
        x0 = x0.view(x0.size(0), -1)
        init_shape = self.init_shape.expand(batch_size, -1)  # (B, 10)
        init_cam = self.init_cam.expand(batch_size, -1)  # (B, 3)

        xc = x0
        xc = self.fc1(xc)
        xc = self.drop1(xc)
        xc = self.fc2(xc)
        xc = self.drop2(xc)

        delta_shape = self.decshape(xc)
        pred_shape = delta_shape + init_shape
        pred_phi = self.decphi(xc)
        pred_camera = self.deccam(xc).reshape(batch_size, -1) + init_cam

        camScale = pred_camera[:, :1].unsqueeze(1)
        camTrans = pred_camera[:, 1:].unsqueeze(1)
        camDepth = self.focal_length / (self.input_size * camScale + 1e-9)

        pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29)
        # unit: 2.2m
        pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone()
        # unit: m
        pred_xyz_jts_29_meter = (
            pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length) \
            * (pred_xyz_jts_29[:, :, 2:] * 2.2 + camDepth) - camTrans
        # unit: 2.2m
        pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / 2.2

        camera_root = pred_xyz_jts_29[:, [0]] * 2.2
        camera_root[:, :, :2] += camTrans
        camera_root[:, :, [2]] += camDepth

        if not self.training:
            # root-centre the skeleton at inference time
            pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]]

        if flip_item is not None:
            assert flip_output is not None
            pred_xyz_jts_29_orig, pred_phi_orig, pred_leaf_orig, \
                pred_shape_orig = flip_item

        if flip_output:
            pred_xyz_jts_29 = self.flip_xyz_coord(pred_xyz_jts_29,
                                                  flatten=False)
        if flip_output and flip_item is not None:
            # average the flipped prediction with the original one
            pred_xyz_jts_29 = (pred_xyz_jts_29 + pred_xyz_jts_29_orig.reshape(
                batch_size, 29, 3)) / 2

        pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1)

        pred_phi = pred_phi.reshape(batch_size, 23, 2)

        if flip_output:
            pred_phi = self.flip_phi(pred_phi)
        if flip_output and flip_item is not None:
            pred_phi = (pred_phi + pred_phi_orig) / 2
            pred_shape = (pred_shape + pred_shape_orig) / 2

        output = self.smpl.hybrik(
            pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2.2,  # unit: meter
            betas=pred_shape.type(self.smpl_dtype),
            phis=pred_phi.type(self.smpl_dtype),
            global_orient=None,
            return_verts=True)
        pred_vertices = output.vertices.float()

        # -0.5 ~ 0.5
        # pred_xyz_jts_24_struct = output.joints.float() / 2.2
        pred_xyz_jts_24_struct = output.joints.float() / 2
        # -0.5 ~ 0.5
        # pred_xyz_jts_17 = output.joints_from_verts.float() / 2.2
        pred_xyz_jts_17 = output.joints_from_verts.float() / 2
        pred_theta_mats = output.rot_mats.float().reshape(
            batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(
            batch_size, 72) / 2
        pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(
            batch_size, 72)
        pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3)

        transl = pred_xyz_jts_29[:, 0, :] * 2.2 - \
            pred_xyz_jts_17[:, 0, :] * 2.2
        transl[:, :2] += camTrans[:, 0]
        transl[:, 2] += camDepth[:, 0, 0]

        new_cam = torch.zeros_like(transl)
        new_cam[:, 1:] = transl[:, :2]
        new_cam[:, 0] = self.focal_length / \
            (self.input_size * transl[:, 2] + 1e-9)

        # pred_aa = output.rot_aa.reshape(batch_size, 24, 3)

        output = dict(
            pred_phi=pred_phi,
            pred_delta_shape=delta_shape,
            pred_shape=pred_shape,
            # pred_aa=pred_aa,
            pred_theta_mats=pred_theta_mats,
            pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1),
            pred_xyz_jts_29=pred_xyz_jts_29_flat,
            pred_xyz_jts_24=pred_xyz_jts_24,
            pred_xyz_jts_24_struct=pred_xyz_jts_24_struct,
            pred_xyz_jts_17=pred_xyz_jts_17_flat,
            pred_vertices=pred_vertices,
            maxvals=maxvals,
            cam_scale=camScale[:, 0],
            cam_trans=camTrans[:, 0],
            cam_root=camera_root,
            pred_camera=new_cam,
            transl=transl,
            # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2),
            # uvd_heatmap=heatmaps,
            # img_feat=x0
        )
        return output
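
    # Reader's note: the dict returned by `forward` uses a subset of the
    # field names of the `ModelOutput` namedtuple defined at module level,
    # so a tuple-style view can be obtained with e.g. `ModelOutput(**out)`
    # if preferred (an illustrative usage, not part of the original
    # interface).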

    def forward_gt_theta(self, gt_theta, gt_beta):
        output = self.smpl(pose_axis_angle=gt_theta,
                           betas=gt_beta,
                           global_orient=None,
                           return_verts=True)
        return output
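

# A minimal usage sketch, assuming a HybrIK-style YAML config and resource
# files are available locally. The paths below ('config.yaml',
# 'SMPL_NEUTRAL.pkl', 'model_files/') are placeholders, not files shipped
# with this module.
if __name__ == '__main__':
    model = HybrIKBaseSMPLCam(cfg_file='config.yaml',
                              smpl_path='SMPL_NEUTRAL.pkl',
                              data_path='model_files')
    model.eval()
    with torch.no_grad():
        img = torch.randn(1, 3, 256, 256)  # one 256x256 RGB crop
        out = model(img)
    # SMPL meshes have 6890 vertices, so this should print (1, 6890, 3)
    print(out['pred_vertices'].shape)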