Spaces:
Runtime error
Runtime error
'''
Adapted and extended from SPIN: https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
'''
import os | |
import cv2 | |
import torch | |
import numpy as np | |
from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis | |
from core import path_config, constants | |
import logging | |
logger = logging.getLogger(__name__) | |
class FitsDict:
    """Dictionary keeping track of the best SMPL fit per image in the training set.

    For every dataset key in ``train_dataset.dataset_dict`` this holds:

    * ``self.fits_dict[name]``: a tensor whose rows are the concatenated
      ``[pose (72) | betas]`` parameters of each sample;
    * ``self.valid_fit_state[name]``: a uint8 tensor with one validity flag
      per sample.
    """

    # Datasets whose fits are stored as a plain ``.npy`` array of
    # concatenated [pose | betas] rows (no validity flags). All other
    # datasets use an ``.npz`` archive with 'pose', 'betas' and 'valid_fit'.
    _NPY_DATASETS = ('h36m',)

    def __init__(self, options, train_dataset):
        """Load the stored fits of every training dataset.

        Args:
            options: run options; ``single_dataset`` is read here and
                ``checkpoint_dir`` later by ``save``.
            train_dataset: object exposing ``dataset_dict`` (dataset keys)
                and, when ``options.single_dataset`` is false, ``datasets``
                (sub-datasets with a ``dataset`` name attribute).
        """
        self.options = options
        self.train_dataset = train_dataset
        self.fits_dict = {}
        self.valid_fit_state = {}
        # Index permutation of the 72 pose entries used by flip_pose.
        self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)
        # Load dictionary state
        for ds_name in train_dataset.dataset_dict.keys():
            if ds_name in self._NPY_DATASETS:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npy')
                self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
                # No stored flags for these datasets: mark every fit valid.
                self.valid_fit_state[ds_name] = torch.ones(
                    len(self.fits_dict[ds_name]), dtype=torch.uint8)
            else:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npz')
                # np.load on an .npz returns an NpzFile that keeps the
                # archive open; the context manager closes it once the
                # member arrays have been read into memory.
                with np.load(dict_file) as fits_npz:
                    opt_pose = torch.from_numpy(fits_npz['pose'])
                    opt_betas = torch.from_numpy(fits_npz['betas'])
                    opt_valid_fit = torch.from_numpy(
                        fits_npz['valid_fit']).to(torch.uint8)
                self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
                                                    dim=1)
                self.valid_fit_state[ds_name] = opt_valid_fit
        # Replace the sub-datasets' SMPL parameters with the loaded fits.
        if not options.single_dataset:
            for ds in train_dataset.datasets:
                if ds.dataset not in self._NPY_DATASETS:
                    ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
                    ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
                    ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()

    def save(self):
        """Save the current fits of every dataset to disk.

        Each dataset is written to ``<checkpoint_dir>/<name>_fits.npy``.
        Only the [pose | betas] rows are saved; ``valid_fit_state`` is not.
        """
        for ds_name in self.train_dataset.dataset_dict.keys():
            dict_file = os.path.join(self.options.checkpoint_dir,
                                     ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve the stored fits for a batch, in augmented image space.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped)``. ``dataset_name``
                and ``ind`` give each sample's dataset key and row index,
                ``rot`` the rotation angles in degrees, and ``is_flipped``
                the per-sample flip flags.

        Returns:
            ``(pose, betas)``: float32 tensors of shape (B, 72) and (B, 10),
            with the rotation and then the flip applied to ``pose``.
        """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]
        # rotate_pose and flip_pose both return new tensors, so the stored
        # fits are never aliased by the returned values.
        pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
        return pose, betas

    def get_vaild_state(self, dataset_name, ind):
        """Return the uint8 validity flags of the stored fits for a batch.

        Args:
            dataset_name: per-sample dataset keys.
            ind: per-sample row indices into each dataset's fits.

        Returns:
            A uint8 tensor of shape (len(dataset_name),).
        """
        valid_fit = torch.zeros(len(dataset_name), dtype=torch.uint8)
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            valid_fit[n] = self.valid_fit_state[ds][i]
        return valid_fit

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    get_valid_state = get_vaild_state

    def __setitem__(self, x, val):
        """Update dictionary entries from fits expressed in augmented space.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped, update)``. Rows
                whose ``update[n]`` is truthy are overwritten.
            val: tuple ``(pose, betas)`` of batched parameters.
        """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val
        # Undo flipping and rotation (inverse order of __getitem__).
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Mirror the SMPL pose parameters of the flagged samples.

        Args:
            pose: (B, 72) axis-angle pose tensor.
            is_flipped: (B,) per-sample flags; any nonzero value means flip.

        Returns:
            A new tensor; rows whose flag is zero are copied unchanged.
        """
        # Boolean masks are the supported form of mask indexing; uint8
        # masks (the former ``.byte()``) are deprecated in PyTorch.
        is_flipped = is_flipped.bool()
        pose_f = pose.clone()
        # Reorder entries with the flip permutation (presumably a
        # left/right joint swap; see constants.SMPL_POSE_FLIP_PERM).
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        # We also negate the second and the third dimension of the
        # axis-angle representation.
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate the global orientation of SMPL poses about the z axis.

        The first three entries (axis-angle global orientation) are
        pre-multiplied by a rotation of ``-rot`` degrees about z. The other
        entries are copied unchanged.

        Args:
            pose: (B, 72) axis-angle pose tensor.
            rot: (B,) tensor of rotation angles in degrees.

        Returns:
            A new tensor with the same shape, dtype and device as ``pose``.
        """
        pose = pose.clone()
        angle = -np.pi * rot / 180.
        cos = torch.cos(angle)
        sin = torch.sin(angle)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        # Batch of z-axis rotation matrices, shape (B, 3, 3).
        R = torch.stack([cos, -sin, zeros,
                         sin, cos, zeros,
                         zeros, zeros, ones], dim=-1).view(-1, 3, 3)
        global_pose = pose[:, :3]
        # torchgeometry returns homogeneous (B, 4, 4) matrices.
        global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
        # Match R to the pose matrices so the matmul does not fail when,
        # e.g., the angles are float64 but the poses are float32.
        R = R.to(dtype=global_pose_rotmat.dtype,
                 device=global_pose_rotmat.device)
        rotated = torch.matmul(R, global_pose_rotmat[:, :3, :3])
        rotated = rotated.detach().cpu().numpy()
        # Convert each rotated 3x3 matrix back to axis-angle with OpenCV.
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i in range(global_pose.shape[0]):
            aa, _ = cv2.Rodrigues(rotated[i])
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(
            device=pose.device, dtype=pose.dtype)
        return pose