|
''' |
|
This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py |
|
''' |
|
import os |
|
import cv2 |
|
import torch |
|
import numpy as np |
|
from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis |
|
|
|
from core import path_config, constants |
|
|
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class FitsDict():
    """Dictionary keeping track of the best fit per image in the training set.

    For every dataset name the object stores one 2-D tensor in ``fits_dict``
    whose rows are indexed by sample; the first 72 columns of a row are the
    pose values and the remaining columns are the shape (betas) values. A
    parallel ``valid_fit_state`` mapping stores one ``uint8`` flag per sample.

    Entries are read with ``fits[(dataset_name, ind, rot, is_flipped)]`` and
    written with ``fits[(dataset_name, ind, rot, is_flipped, update)] = (pose,
    betas)``; reads apply the rotation/flip augmentation, writes undo it
    before storing.
    """

    # Number of leading columns of a fit row that hold pose values.
    _POSE_DIM = 72
    # Number of shape values returned per sample by ``__getitem__``.
    _BETAS_DIM = 10
    # Datasets whose fits are a single ``.npy`` array and are all marked valid.
    _NPY_FIT_DATASETS = ('h36m',)

    def __init__(self, options, train_dataset):
        """Load the stored fits for every dataset of ``train_dataset``.

        Args:
            options: configuration object; ``single_dataset`` is read here and
                ``checkpoint_dir`` is read by :meth:`save`.
            train_dataset: object exposing ``dataset_dict`` (keyed by dataset
                name) and, when ``options.single_dataset`` is false,
                ``datasets`` (each with a ``dataset`` name attribute).
        """
        self.options = options
        self.train_dataset = train_dataset
        self.fits_dict = {}
        self.valid_fit_state = {}

        # Index permutation used to flip pose parameters left/right.
        self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)

        # Load the stored state of every dataset.
        for ds_name in train_dataset.dataset_dict:
            if ds_name in self._NPY_FIT_DATASETS:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npy')
                fits = torch.from_numpy(np.load(dict_file))
                self.fits_dict[ds_name] = fits
                # Every sample of these datasets is treated as a valid fit.
                self.valid_fit_state[ds_name] = torch.ones(len(fits),
                                                           dtype=torch.uint8)
            else:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npz')
                # np.load on an .npz returns an archive holding an open file;
                # the context manager closes it once the arrays are read.
                with np.load(dict_file) as archive:
                    opt_pose = torch.from_numpy(archive['pose'])
                    opt_betas = torch.from_numpy(archive['betas'])
                    opt_valid_fit = torch.from_numpy(
                        archive['valid_fit']).to(torch.uint8)
                self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
                                                    dim=1)
                self.valid_fit_state[ds_name] = opt_valid_fit

        if not options.single_dataset:
            for ds in train_dataset.datasets:
                if ds.dataset in self._NPY_FIT_DATASETS:
                    continue
                # The arrays come from ``Tensor.numpy()`` on slices of the
                # stored tensors, so they are views rather than copies.
                fits = self.fits_dict[ds.dataset]
                ds.pose = fits[:, :self._POSE_DIM].numpy()
                ds.betas = fits[:, self._POSE_DIM:].numpy()
                ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()

    def save(self):
        """Save dictionary state to disk.

        Writes ``<checkpoint_dir>/<dataset>_fits.npy`` for every dataset. Only
        the fit tensors are written; ``valid_fit_state`` is not saved.
        """
        for ds_name in self.train_dataset.dataset_dict:
            dict_file = os.path.join(self.options.checkpoint_dir,
                                     ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve dictionary entries.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped)`` with one entry
                per batch sample: dataset names, row indices, rotation angles
                in degrees and flip flags.

        Returns:
            ``(pose, betas)`` tensors of shape ``(batch, 72)`` and
            ``(batch, 10)``, with the pose rotated and then flipped.
        """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, self._POSE_DIM))
        betas = torch.zeros((batch_size, self._BETAS_DIM))
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:self._POSE_DIM]
            betas[n, :] = params[self._POSE_DIM:]

        # Apply the augmentation: rotation first, then flipping.
        pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
        return pose, betas

    def get_vaild_state(self, dataset_name, ind):
        """Return the stored validity flag of each requested sample.

        Args:
            dataset_name: sequence of dataset names, one per sample.
            ind: row indices into the corresponding datasets.

        Returns:
            ``uint8`` tensor of length ``len(dataset_name)``.
        """
        batch_size = len(dataset_name)
        valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            valid_fit[n] = self.valid_fit_state[ds][i]
        return valid_fit

    def __setitem__(self, x, val):
        """Update dictionary entries.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped, update)``; only
                samples whose ``update`` entry is truthy are written.
            val: tuple ``(pose, betas)`` of per-sample parameters in the
                augmented (rotated and flipped) frame.
        """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val

        # Undo the augmentation in reverse order: un-flip, then rotate back.
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Flip SMPL pose parameters.

        Args:
            pose: ``(batch, P)`` tensor of pose parameters; not modified.
            is_flipped: per-sample flags; non-zero entries select the rows
                to flip.

        Returns:
            A new tensor where the selected rows are permuted with
            ``self.flipped_parts`` and have every 2nd and 3rd component of
            each 3-vector negated; other rows are copied unchanged.
        """
        # Boolean masks are the supported mask dtype for tensor indexing;
        # uint8 masks are deprecated and emit a warning on each call.
        mask = is_flipped.bool()
        pose_f = pose.clone()
        pose_f[mask, :] = pose[mask][:, self.flipped_parts]
        # Negate the second and third element of every 3-vector.
        pose_f[mask, 1::3] *= -1
        pose_f[mask, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate SMPL pose parameters by rot degrees.

        Only the first three values of each row (the global orientation, in
        axis-angle form) are changed: the orientation is left-multiplied by a
        rotation about the z axis whose angle is ``-rot`` degrees.

        Args:
            pose: ``(batch, P)`` tensor of pose parameters; not modified.
            rot: ``(batch,)`` tensor of angles in degrees.

        Returns:
            A new tensor with the rotated global orientation.
        """
        pose = pose.clone()
        # Build the rotation with pose's dtype/device so the matmul below
        # cannot fail on a dtype or device mismatch between the two inputs.
        rot = rot.to(device=pose.device, dtype=pose.dtype)
        angle = -np.pi * rot / 180.
        cos = torch.cos(angle)
        sin = torch.sin(angle)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        R = torch.stack([
            torch.stack([cos, -sin, zeros], dim=-1),
            torch.stack([sin, cos, zeros], dim=-1),
            torch.stack([zeros, zeros, ones], dim=-1),
        ],
                        dim=1)

        global_pose = pose[:, :3]
        # Only the upper-left 3x3 rotation part of the returned matrices is
        # used.
        global_rotmat = angle_axis_to_rotation_matrix(global_pose)[:, :3, :3]
        rotated = torch.matmul(R, global_rotmat).cpu().numpy()

        # Convert back to axis-angle one sample at a time with OpenCV.
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i, rotmat in enumerate(rotated):
            aa, _ = cv2.Rodrigues(rotmat)
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose
|
|