from collections import namedtuple
import os

import numpy as np
import torch
import torch.nn as nn
import yaml
from torch.nn import functional as F

from .layers.Resnet import ResNet
from .layers.smpl.SMPL import SMPL_layer
ModelOutput = namedtuple(
    typename='ModelOutput',
    field_names=[
        'pred_shape', 'pred_theta_mats', 'pred_phi', 'pred_delta_shape',
        'pred_leaf', 'pred_uvd_jts', 'pred_xyz_jts_29', 'pred_xyz_jts_24',
        'pred_xyz_jts_24_struct', 'pred_xyz_jts_17', 'pred_vertices',
        'maxvals', 'cam_scale', 'cam_trans', 'cam_root', 'uvd_heatmap',
        'transl', 'img_feat', 'pred_camera', 'pred_aa'
    ])
ModelOutput.__new__.__defaults__ = (None, ) * len(ModelOutput._fields)
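# Every field defaults to None, so callers can build partial outputs, e.g.
# ModelOutput(pred_shape=shape, pred_vertices=verts) leaves all other fields
# unset.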


def update_config(config_file):
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config
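
# A minimal config sketch, inferred from the keys read in
# HybrIKBaseSMPLCam.__init__ below (the values shown are illustrative
# assumptions, not canonical):
#
#   MODEL:
#     NUM_LAYERS: 34
#     NUM_DECONV_FILTERS: [256, 256, 256]
#     NUM_JOINTS: 29
#     HEATMAP_SIZE: [64, 64]   # [height, width]
#     FOCAL_LENGTH: 1000
#     POST:
#       NORM_TYPE: softmax
#     EXTRA:
#       DEPTH_DIM: 64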


def norm_heatmap(norm_type, heatmap):
    # Input tensor shape: [N, C, ...]
    shape = heatmap.shape
    if norm_type == 'softmax':
        heatmap = heatmap.reshape(*shape[:2], -1)
        # global softmax over the flattened spatial dimensions
        heatmap = F.softmax(heatmap, 2)
        return heatmap.reshape(*shape)
    else:
        raise NotImplementedError(f'Unsupported norm_type: {norm_type}')
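
# Usage sketch (illustrative): each channel of the heatmap becomes a
# probability distribution over its flattened spatial locations.
#
#   hm = torch.randn(2, 29, 64, 64, 64)
#   hm = norm_heatmap('softmax', hm)
#   # hm.reshape(2, 29, -1).sum(dim=2) is now all ones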


class HybrIKBaseSMPLCam(nn.Module):

    def __init__(self,
                 cfg_file,
                 smpl_path,
                 data_path,
                 norm_layer=nn.BatchNorm2d):
        super().__init__()
        cfg = update_config(cfg_file)['MODEL']

        self.deconv_dim = cfg['NUM_DECONV_FILTERS']
        self._norm_layer = norm_layer
        self.num_joints = cfg['NUM_JOINTS']
        self.norm_type = cfg['POST']['NORM_TYPE']
        self.depth_dim = cfg['EXTRA']['DEPTH_DIM']
        self.height_dim = cfg['HEATMAP_SIZE'][0]
        self.width_dim = cfg['HEATMAP_SIZE'][1]
        self.smpl_dtype = torch.float32

        backbone = ResNet
        self.preact = backbone(f"resnet{cfg['NUM_LAYERS']}")
        # Load ImageNet-pretrained weights from torchvision.
        # NOTE: the `pretrained=` flag is deprecated in torchvision >= 0.13
        # in favour of the `weights=` argument.
        import torchvision.models as tm
        if cfg['NUM_LAYERS'] == 101:
            x = tm.resnet101(pretrained=True)
            self.feature_channel = 2048
        elif cfg['NUM_LAYERS'] == 50:
            x = tm.resnet50(pretrained=True)
            self.feature_channel = 2048
        elif cfg['NUM_LAYERS'] == 34:
            x = tm.resnet34(pretrained=True)
            self.feature_channel = 512
        elif cfg['NUM_LAYERS'] == 18:
            x = tm.resnet18(pretrained=True)
            self.feature_channel = 512
        else:
            raise NotImplementedError(
                f"Unsupported NUM_LAYERS: {cfg['NUM_LAYERS']}")
        # Copy every pretrained tensor whose name and shape match the backbone.
        model_state = self.preact.state_dict()
        state = {
            k: v
            for k, v in x.state_dict().items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(state)
        self.preact.load_state_dict(model_state)
        self.deconv_layers = self._make_deconv_layer()
        self.final_layer = nn.Conv2d(self.deconv_dim[2],
                                     self.num_joints * self.depth_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        h36m_jregressor = np.load(
            os.path.join(data_path, 'J_regressor_h36m.npy'))
        self.smpl = SMPL_layer(smpl_path,
                               h36m_jregressor=h36m_jregressor,
                               dtype=self.smpl_dtype)

        self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23))
        self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23),
                               (25, 26), (27, 28))
        self.leaf_pairs = ((0, 1), (3, 4))
        self.root_idx_smpl = 0

        # mean shape
        init_shape = np.load(os.path.join(data_path, 'h36m_mean_beta.npy'))
        self.register_buffer('init_shape',
                             torch.from_numpy(init_shape).float())
        self.register_buffer('init_cam',
                             torch.tensor([0.9, 0.0, 0.0],
                                          dtype=torch.float32))

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(self.feature_channel, 1024)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(p=0.5)
        self.decshape = nn.Linear(1024, 10)
        self.decphi = nn.Linear(1024, 23 * 2)  # [cos(phi), sin(phi)]
        self.deccam = nn.Linear(1024, 3)

        self.focal_length = cfg['FOCAL_LENGTH']
        self.input_size = 256.0
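
    # Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) below exactly
    # doubles the spatial resolution, so the three-stage deconvolution head
    # upsamples the backbone feature map by 8x (e.g. 8x8 -> 64x64 for a
    # 256x256 input).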
    def _make_deconv_layer(self):
        deconv_layers = []
        deconv1 = nn.ConvTranspose2d(self.feature_channel,
                                     self.deconv_dim[0],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn1 = self._norm_layer(self.deconv_dim[0])
        deconv2 = nn.ConvTranspose2d(self.deconv_dim[0],
                                     self.deconv_dim[1],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn2 = self._norm_layer(self.deconv_dim[1])
        deconv3 = nn.ConvTranspose2d(self.deconv_dim[1],
                                     self.deconv_dim[2],
                                     kernel_size=4,
                                     stride=2,
                                     padding=1,
                                     bias=False)
        bn3 = self._norm_layer(self.deconv_dim[2])

        deconv_layers.append(deconv1)
        deconv_layers.append(bn1)
        deconv_layers.append(nn.ReLU(inplace=True))
        deconv_layers.append(deconv2)
        deconv_layers.append(bn2)
        deconv_layers.append(nn.ReLU(inplace=True))
        deconv_layers.append(deconv3)
        deconv_layers.append(bn3)
        deconv_layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*deconv_layers)

    def _initialize(self):
        for m in self.deconv_layers.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

    def flip_uvd_coord(self, pred_jts, shift=False, flatten=True):
        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3
            num_batches = pred_jts.shape[0]

        # Mirror the u (x) axis in place; without `shift` the mirror is
        # additionally offset by one heatmap pixel (1 / width_dim in
        # normalised units) to stay aligned with the flipped pixel grid.
        if shift:
            pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        else:
            pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]

        # Swap each left/right joint pair.
        for pair in self.joint_pairs_29:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0, dim1)).long()
            inv_idx = torch.Tensor((dim1, dim0)).long()
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)

        return pred_jts

    def flip_xyz_coord(self, pred_jts, flatten=True):
        if flatten:
            assert pred_jts.dim() == 2
            num_batches = pred_jts.shape[0]
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
        else:
            assert pred_jts.dim() == 3
            num_batches = pred_jts.shape[0]

        # Mirror the x axis, then swap each left/right joint pair.
        pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        for pair in self.joint_pairs_29:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0, dim1)).long()
            inv_idx = torch.Tensor((dim1, dim0)).long()
            pred_jts[:, idx] = pred_jts[:, inv_idx]

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)

        return pred_jts

    def flip_phi(self, pred_phi):
        # Negating the sin component mirrors each twist angle; the cos
        # component is unchanged.
        pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]

        # Swap left/right pairs; indices are offset by 1 because phi has no
        # entry for the root joint.
        for pair in self.joint_pairs_24:
            dim0, dim1 = pair
            idx = torch.Tensor((dim0 - 1, dim1 - 1)).long()
            inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long()
            pred_phi[:, idx] = pred_phi[:, inv_idx]

        return pred_phi

    def forward(self,
                x,
                flip_item=None,
                flip_output=False,
                gt_uvd=None,
                gt_uvd_weight=None,
                **kwargs):
        batch_size = x.shape[0]

        x0 = self.preact(x)
        out = self.deconv_layers(x0)
        out = self.final_layer(out)

        out = out.reshape((out.shape[0], self.num_joints, -1))
        maxvals, _ = torch.max(out, dim=2, keepdim=True)

        out = norm_heatmap(self.norm_type, out)
        assert out.dim() == 3, out.shape
        heatmaps = out / out.sum(dim=2, keepdim=True)
        heatmaps = heatmaps.reshape(
            (heatmaps.shape[0], self.num_joints, self.depth_dim,
             self.height_dim, self.width_dim))

        # Marginal distributions along each axis.
        hm_x0 = heatmaps.sum((2, 3))  # sum over depth and height -> p(x)
        hm_y0 = heatmaps.sum((2, 4))  # sum over depth and width  -> p(y)
        hm_z0 = heatmaps.sum((3, 4))  # sum over height and width -> p(z)

        # Note: a single range tensor is reused for all three axes, which
        # assumes DEPTH_DIM and both HEATMAP_SIZE entries are equal
        # (e.g. all 64).
        range_tensor = torch.arange(hm_x0.shape[-1],
                                    dtype=torch.float32,
                                    device=hm_x0.device)
        hm_x = hm_x0 * range_tensor
        hm_y = hm_y0 * range_tensor
        hm_z = hm_z0 * range_tensor

        coord_x = hm_x.sum(dim=2, keepdim=True)
        coord_y = hm_y.sum(dim=2, keepdim=True)
        coord_z = hm_z.sum(dim=2, keepdim=True)

        coord_x = coord_x / float(self.width_dim) - 0.5
        coord_y = coord_y / float(self.height_dim) - 0.5
        coord_z = coord_z / float(self.depth_dim) - 0.5

        # -0.5 ~ 0.5
        pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)
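        # pred_uvd_jts_29 is an integral ("soft-argmax") pose estimate: each
        # coordinate is the expectation sum_i p_i * i under its marginal
        # distribution, rescaled from [0, dim) to roughly [-0.5, 0.5).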

        x0 = self.avg_pool(x0)
        x0 = x0.view(x0.size(0), -1)

        init_shape = self.init_shape.expand(batch_size, -1)  # (B, 10)
        init_cam = self.init_cam.expand(batch_size, -1)  # (B, 3)

        xc = x0
        xc = self.fc1(xc)
        xc = self.drop1(xc)
        xc = self.fc2(xc)
        xc = self.drop2(xc)

        delta_shape = self.decshape(xc)
        pred_shape = delta_shape + init_shape
        pred_phi = self.decphi(xc)
        pred_camera = self.deccam(xc).reshape(batch_size, -1) + init_cam

        # Weak-perspective camera: [scale, trans_x, trans_y].
        camScale = pred_camera[:, :1].unsqueeze(1)
        camTrans = pred_camera[:, 1:].unsqueeze(1)
        camDepth = self.focal_length / (self.input_size * camScale + 1e-9)
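        # Worked example (values are assumptions, not from the config): with
        # focal_length = 1000, input_size = 256 and camScale = 0.9,
        # camDepth = 1000 / (256 * 0.9) ~= 4.34 m; a larger predicted scale
        # places the root joint closer to the camera.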

        pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29)
        pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone()  # unit: 2.2m
        pred_xyz_jts_29_meter = (pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length) \
            * (pred_xyz_jts_29[:, :, 2:] * 2.2 + camDepth) - camTrans  # unit: m
        pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / 2.2  # unit: 2.2m

        camera_root = pred_xyz_jts_29[:, [0]] * 2.2
        camera_root[:, :, :2] += camTrans
        camera_root[:, :, [2]] += camDepth

        if not self.training:
            pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]]

        if flip_item is not None:
            assert flip_output, 'flip_item requires flip_output=True'
            pred_xyz_jts_29_orig, pred_phi_orig, pred_leaf_orig, pred_shape_orig = flip_item

        if flip_output:
            pred_xyz_jts_29 = self.flip_xyz_coord(pred_xyz_jts_29,
                                                  flatten=False)
        if flip_output and flip_item is not None:
            # Average the flipped prediction with the original-view one.
            pred_xyz_jts_29 = (pred_xyz_jts_29 + pred_xyz_jts_29_orig.reshape(
                batch_size, 29, 3)) / 2

        pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1)

        pred_phi = pred_phi.reshape(batch_size, 23, 2)
        if flip_output:
            pred_phi = self.flip_phi(pred_phi)
        if flip_output and flip_item is not None:
            pred_phi = (pred_phi + pred_phi_orig) / 2
            pred_shape = (pred_shape + pred_shape_orig) / 2

        output = self.smpl.hybrik(
            pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2.2,  # unit: meter
            betas=pred_shape.type(self.smpl_dtype),
            phis=pred_phi.type(self.smpl_dtype),
            global_orient=None,
            return_verts=True)
        pred_vertices = output.vertices.float()

        # -0.5 ~ 0.5
        # pred_xyz_jts_24_struct = output.joints.float() / 2.2
        pred_xyz_jts_24_struct = output.joints.float() / 2
        # -0.5 ~ 0.5
        # pred_xyz_jts_17 = output.joints_from_verts.float() / 2.2
        pred_xyz_jts_17 = output.joints_from_verts.float() / 2
        pred_theta_mats = output.rot_mats.float().reshape(batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72) / 2
        pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
        pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3)

        # Translation that aligns the SMPL output with the estimated camera.
        transl = pred_xyz_jts_29[:, 0, :] * 2.2 - pred_xyz_jts_17[:, 0, :] * 2.2
        transl[:, :2] += camTrans[:, 0]
        transl[:, 2] += camDepth[:, 0, 0]

        new_cam = torch.zeros_like(transl)
        new_cam[:, 1:] = transl[:, :2]
        new_cam[:, 0] = self.focal_length / (self.input_size * transl[:, 2] + 1e-9)
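        # new_cam repacks the translation in weak-perspective form
        # [scale, trans_x, trans_y], inverting camDepth = focal_length /
        # (input_size * scale), i.e. scale = focal_length / (input_size * depth).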
        # pred_aa = output.rot_aa.reshape(batch_size, 24, 3)

        output = dict(
            pred_phi=pred_phi,
            pred_delta_shape=delta_shape,
            pred_shape=pred_shape,
            # pred_aa=pred_aa,
            pred_theta_mats=pred_theta_mats,
            pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1),
            pred_xyz_jts_29=pred_xyz_jts_29_flat,
            pred_xyz_jts_24=pred_xyz_jts_24,
            pred_xyz_jts_24_struct=pred_xyz_jts_24_struct,
            pred_xyz_jts_17=pred_xyz_jts_17_flat,
            pred_vertices=pred_vertices,
            maxvals=maxvals,
            cam_scale=camScale[:, 0],
            cam_trans=camTrans[:, 0],
            cam_root=camera_root,
            pred_camera=new_cam,
            transl=transl,
            # uvd_heatmap=torch.stack([hm_x0, hm_y0, hm_z0], dim=2),
            # uvd_heatmap=heatmaps,
            # img_feat=x0
        )
        return output

    def forward_gt_theta(self, gt_theta, gt_beta):
        output = self.smpl(pose_axis_angle=gt_theta,
                           betas=gt_beta,
                           global_orient=None,
                           return_verts=True)
        return output
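

# Usage sketch (illustrative; the paths below are assumptions, not part of
# this module). Note this file uses relative imports, so it must be imported
# as part of its package rather than run directly:
#
#   model = HybrIKBaseSMPLCam(
#       cfg_file='configs/hybrik_config.yaml',     # hypothetical path
#       smpl_path='model_files/SMPL_NEUTRAL.pkl',  # hypothetical path
#       data_path='model_files')                   # needs J_regressor_h36m.npy
#                                                  # and h36m_mean_beta.npy
#   model.eval()
#   with torch.no_grad():
#       out = model(torch.randn(1, 3, 256, 256))   # one 256x256 crop
#   out['pred_xyz_jts_17'].shape                   # torch.Size([1, 51])
#   out['pred_vertices'].shape                     # torch.Size([1, 6890, 3])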