trumans / sample_hsi.py
jnnan's picture
Upload 68 files
aeba71c verified
raw
history blame
13.4 kB
import os
import pdb
import pickle as pkl
from torch.utils.data import Dataset, DataLoader
# from omegaconf import DictConfig, OmegaConf
from scipy.spatial.transform import Rotation as R
from models.joints_to_smplx import joints_to_smpl, JointsToSMPLX
from utils import *
from constants import *
from datasets.trumans import TrumansDataset
from models.synhsi import Unet
# from hydra import compose, initialize
import yaml
# Names of action categories.
# NOTE(review): nothing in the visible part of this file reads ACT_TYPE, and
# get_guidance() branches on 'none', 'write', 'scene', 'grasp' and
# 'pure_inter' -- 'artic' has no branch there, while 'write' and 'pure_inter'
# are missing from this list. Confirm whether the list is still authoritative.
ACT_TYPE = ['scene', 'grasp', 'artic', 'none']
def convert_trajectory(trajectory):
    """Convert a sequence of ``{'x', 'y', 'z'}`` mappings into an ``(N, 3)`` array.

    Args:
        trajectory: Iterable of mappings, each providing the keys ``'x'``,
            ``'y'`` and ``'z'`` (scalar values are assumed).

    Returns:
        A NumPy array with one ``[x, y, z]`` row per input point.  An empty
        input yields shape ``(0, 3)`` instead of ``(0,)`` so that column
        indexing such as ``arr[:, [0, 2]]`` keeps working.

    Raises:
        KeyError: If a dict point lacks one of the three coordinate keys.
    """
    points = [[point['x'], point['y'], point['z']] for point in trajectory]
    # For non-empty input the array is already (N, 3) and reshape is a no-op;
    # it only matters for the empty list, which NumPy would make 1-D.
    return np.array(points).reshape(-1, 3)
def get_base_speed(cfg, trajectory, is_zup=True):
    """Return the number of trajectory frames per unit of planar path length.

    The path length is measured on columns 0 and 2 of the trajectory, after
    an optional ``zup_to_yup`` conversion.

    Args:
        cfg: Unused; kept so existing call sites keep working.
        trajectory: Array indexed as ``[frame, coordinate]`` with at least
            three columns.
        is_zup: When true the trajectory is first passed through
            ``zup_to_yup`` (presumably a Z-up to Y-up axis swap -- defined
            outside this file).

    Returns:
        ``frames // path_length`` as a NumPy float.  When the planar path
        length is not positive (single point or stationary path) the ratio
        is undefined; the frame count (at least 1) is returned instead so the
        result stays finite and callers can still convert it with ``int()``.
    """
    trajectory_layer = zup_to_yup(trajectory) if is_zup else trajectory
    planar_track = trajectory_layer[:, [0, 2]]
    # Sum of the lengths of consecutive planar segments.
    distance = np.sum(np.linalg.norm(planar_track[1:] - planar_track[:-1], axis=1))
    num_frames = trajectory_layer.shape[0]
    if distance > 0:
        speed = num_frames // distance
    else:
        # Dividing by zero would yield inf/nan, which int() cannot convert.
        speed = np.float64(max(num_frames, 1))
    print('Base Speed:', speed, flush=True)
    return speed
def get_guidance(cfg, trajectory, samplers, act_type='none', speed=35):
    """Build the initial transform and the per-step guidance lists for sampling.

    Args:
        cfg: Configuration object.  Reads ``action_type``, ``len_pre``,
            ``len_act``, ``stay_and_act``, ``device``, ``batch_size``,
            ``action_id``, ``dataset.seq_len`` and ``dataset.nb_actions``.
        trajectory: NumPy array of way points indexed as ``[frame, xyz]``.
            NOTE(review): the ``'grasp'`` branch additionally reads
            ``trajectory['Object']``, which a plain array does not support --
            confirm what callers pass for that action type.
        samplers: Mapping with the keys ``'body'`` and ``'hand'``.
        act_type: ``'none'``, ``'write'``, ``'scene'``, ``'grasp'`` or
            ``'pure_inter'``.  Any other value leaves the three lists empty.
        speed: Stride, in trajectory frames, between consecutive goals.
            Values below 1 are clamped to 1.

    Returns:
        ``(mat_init, goal_list, action_label_list, sampler_list)``.
        ``mat_init`` is a ``(batch_size, 4, 4)`` float tensor whose
        translation is the first mid point (x and z only) and whose rotation
        is about the y axis, derived from the offset between mid point 0 and
        mid point ``len_pre + 1``.  The three lists hold one entry per
        sampling step.
    """
    trajectory_layer = trajectory
    print(trajectory_layer.shape)

    # Extra steps appended at the final way point.  The parentheses matter:
    # without them the conditional expression swallows ``cfg.len_act``.
    n_tail = cfg.len_act + (1 if cfg.stay_and_act else 0)
    # NOTE(review): the way-point layout is chosen from cfg.action_type while
    # the branches below use the ``act_type`` argument; the visible caller
    # passes the same value for both.
    if cfg.action_type != 'pure_inter':
        # range() rejects a zero step, so never let the stride drop below 1.
        stride = max(1, int(speed))
        waypoint_ids = (
            [0] * cfg.len_pre
            + list(range(0, len(trajectory_layer), stride))
            + [-1] * n_tail
        )
    else:
        waypoint_ids = [0] * cfg.len_pre + [0] + [0] * n_tail
    midpoints = torch.tensor(trajectory_layer[waypoint_ids]).float().to(cfg.device)
    max_step = midpoints.shape[0] - 1

    mat_init = cfg.batch_size * [np.eye(4)]
    mat_init = torch.from_numpy(np.stack(mat_init, axis=0)).float().to(cfg.device)
    print(midpoints)
    mat_init[:, 0, 3] = midpoints[0, 0]
    mat_init[:, 2, 3] = midpoints[0, 2]
    dx = midpoints[cfg.len_pre + 1, 0] - midpoints[0, 0]
    dz = midpoints[cfg.len_pre + 1, 2] - midpoints[0, 2]
    print(-np.arctan2(dx.item(), dz.item()), dx, dz)
    mat_rot_y = R.from_rotvec(np.array([0, np.arctan2(dx.item(), dz.item()), 0])).as_matrix()
    mat_init[:, :3, :3] = torch.from_numpy(mat_rot_y).float().to(cfg.device)

    def new_goal(n_points):
        # Zero-filled goal tensor of shape (batch_size, n_points, 3).
        return torch.zeros((cfg.batch_size, n_points, 3)).float().to(cfg.device)

    def new_labels():
        # Zero-filled label tensor of shape (batch_size, seq_len, nb_actions).
        return torch.zeros(
            (cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)
        ).float().to(cfg.device)

    goal_list = []
    action_label_list = []
    sampler_list = []

    if act_type == 'none':
        for s in range(max_step):
            goal = new_goal(1)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label_list.append(new_labels())
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['body'])
    elif act_type == 'write':
        # Every 4th trajectory point, consumed in chunks of 16 goals per step.
        midpoints = torch.from_numpy(trajectory_layer[::4]).float().to(cfg.device)
        for s in range(midpoints.shape[0] // 16):
            goal = new_goal(16)
            goal[:, :] = midpoints[s * 16: (s + 1) * 16]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label_list.append(new_labels())
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['hand'])
    elif act_type == 'scene':
        for s in range(max_step):
            goal = new_goal(1)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            sampler_list.append(samplers['body'])
            action_label = new_labels()
            if s > max_step - cfg.len_act:
                action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'grasp':
        for s in range(max_step):
            if s != max_step - cfg.len_act:
                goal = new_goal(1)
                goal[:, :] = midpoints[s + 1]
                goal_list.append(goal)
            else:
                grasp_goal = zup_to_yup(np.array(trajectory['Object'])).reshape((cfg.batch_size, 1, 3))
                # The single grasp target is broadcast over three goal slots.
                goal = new_goal(3)
                goal[:, :] = torch.from_numpy(grasp_goal).float().to(cfg.device)
                goal_list.append(goal)
            action_label = new_labels()
            if s < max_step - cfg.len_act:
                sampler_list.append(samplers['body'])
            elif s == max_step - cfg.len_act:
                sampler_list.append(samplers['hand'])
            else:
                sampler_list.append(samplers['body'])
                if cfg.action_id != -1:
                    action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'pure_inter':
        n_steps = cfg.len_pre + cfg.len_act
        sampler_list += [samplers['body']] * n_steps
        goal = new_goal(1)
        goal[:, :] = midpoints[0]
        # Independent copies, so mutating one step cannot leak into another.
        goal_list += [goal.clone() for _ in range(n_steps)]
        idle_label = new_labels()
        active_label = idle_label.clone()
        # NOTE(review): the other branches write 1. here; confirm 2. is intended.
        active_label[:, :, cfg.action_id] = 2.
        action_label_list += [idle_label.clone() for _ in range(cfg.len_pre)]
        action_label_list += [active_label.clone() for _ in range(cfg.len_act)]
    return mat_init, goal_list, action_label_list, sampler_list
def sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list):
    """Run the samplers step by step and stitch the generated frames together.

    Each step calls ``sampler.p_sample_loop`` with the frames carried over
    from the previous step (``fixed_points``), then re-derives the local
    transform ``mat`` from the first nine values of one generated frame
    (reshaped to three 3-D points and aligned with ``rest_pelvis``).

    Args:
        cfg: Configuration object.  Reads ``continue_last``, ``method_name``,
            ``exp_dir``, ``scene_name``, ``len_pre``, ``batch_size``,
            ``device`` and ``dataset.seq_len``.
        mat: Initial ``(batch_size, 4, 4)`` transform tensor.
        obj_locs: Forwarded unchanged to ``p_sample_loop``.
        goal_list: One goal tensor per step; its length sets the step count.
        action_label_list: One label entry per step.
        sampler_list: One sampler per step; each must expose ``fixed_frame``,
            ``dataset.normalize_torch`` and ``p_sample_loop``.

    Returns:
        NumPy array of the generated points concatenated along axis 1, with
        the first ``cnt_seq_len - cnt_fixed_frame`` frames (accumulated over
        the ``len_pre`` steps) removed.

    Raises:
        ValueError: From ``np.concatenate`` when ``goal_list`` is empty.
    """
    max_step = len(goal_list)
    fixed_points = None
    fixed_frame = 2
    points_all = []
    cnt_fixed_frame = 0
    cnt_seq_len = 0
    for s in range(max_step):
        print('step', s)
        sampler = sampler_list[s]
        if s != 0:
            # Express the carried-over frames in the new local frame.
            fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))
        elif cfg.continue_last:
            # Resume from the run whose numeric suffix is one lower.  The whole
            # suffix is stripped, so 'name_10' maps to 'name_9' rather than
            # the 'name_19' that dropping a single character would give.
            method_id = cfg.method_name.split('_')[-1]
            name_prefix = cfg.method_name[:len(cfg.method_name) - len(method_id)]
            method_name_last = name_prefix + str(int(method_id) - 1)
            mat = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_mat.npy'))).to(sampler.device)
            fixed_points = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_fixed_points.npy'))).to(sampler.device)
            fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))

        samples, _ = sampler.p_sample_loop(fixed_points, obj_locs, mat, cfg.scene_name, goal_list[s], action_label_list[s])

        # Frames produced during the first len_pre steps are trimmed at the end.
        if 0 <= s < cfg.len_pre:
            cnt_fixed_frame += sampler.fixed_frame
            cnt_seq_len += cfg.dataset.seq_len

        points_gene = samples[-1]
        points_gene_np = points_gene.reshape(cfg.batch_size, cfg.dataset.seq_len, -1, 3).cpu().numpy()
        if s == 0 or fixed_frame == 0:
            # TODO(review): with fixed_frame == 0 this slice is [-1:], i.e.
            # only the last frame -- confirm that is intended.
            points_all.append(points_gene_np[:, fixed_frame - 1:])
        elif fixed_frame > 0:
            # Skip the frames that overlap with the previous step.
            points_all.append(points_gene_np[:, fixed_frame:])

        # Overlap required by the next step's sampler (own value on the last step).
        fixed_frame = sampler_list[s].fixed_frame if s == max_step - 1 else sampler_list[s + 1].fixed_frame

        pelvis_new = points_gene[:, -fixed_frame, :9].cpu().numpy().reshape(cfg.batch_size, 3, 3)
        trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], cfg.batch_size, axis=0)
        for ip, pn in enumerate(pelvis_new):
            _, ret_R, ret_t = rigid_transform_3D(np.matrix(pn), rest_pelvis, False)
            ret_t[1] = 0.0  # drop the translation component at index 1
            rot_euler = R.from_matrix(ret_R).as_euler('zxy')
            # Keep only the third 'zxy' Euler angle (rotation about y).
            shift_euler = np.array([0, 0, rot_euler[2]])
            shift_rot_matrix2 = R.from_euler('zxy', shift_euler).as_matrix()
            trans_mats[ip, :3, :3] = shift_rot_matrix2
            trans_mats[ip, :3, 3] = ret_t.reshape(-1)
        mat = torch.from_numpy(trans_mats).to(device=cfg.device, dtype=torch.float32)

        if fixed_frame > 0:
            fixed_points = points_gene[:, -fixed_frame:]
        if s == max_step - 1:
            # NOTE(review): the message is printed but the np.save calls are
            # disabled, so nothing is written for a later continue_last run.
            print('Saved Mat and Fixed Points', flush=True)
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_mat.npy'), mat.cpu().numpy())
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_fixed_points.npy'), fixed_points.cpu().numpy())

    points_all = np.concatenate(points_all, axis=1)
    points_all = points_all[:, cnt_seq_len - cnt_fixed_frame:]
    return points_all
def sample_wrapper(trajectory, obj_locs):
    """Generate motion for a trajectory and return the resulting mesh vertices.

    Loads ``config/config_sample_synhsi.yaml``, builds the joints-to-SMPL-X
    model, the body U-Net and the body sampler, runs the step-wise sampling
    and converts the generated joints with ``joints_to_smpl``.

    Args:
        trajectory: Sequence of mappings with the keys ``'x'``, ``'y'`` and
            ``'z'``; converted with ``convert_trajectory``.
        obj_locs: Forwarded unchanged to ``sample_step``.

    Returns:
        ``vertices.tolist()`` for the last batch element only; the vertices
        of earlier batch elements are overwritten inside the loop.

    Raises:
        AttributeError: When ``cfg.batch_size`` is 0, because no vertices are
            produced and ``None.tolist()`` is attempted.
    """
    trajectory = convert_trajectory(trajectory)

    with open('config/config_sample_synhsi.yaml') as f:
        cfg = yaml.safe_load(f)
    cfg = dotDict(cfg)
    print(cfg)

    # NOTE(review): models are placed on this device, while get_guidance and
    # sample_step create tensors on cfg.device -- confirm both agree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location lets checkpoints saved on a GPU load on a CPU-only host,
    # which the device fallback above is meant to support.
    model_joints_to_smplx = JointsToSMPLX(**cfg.model.model_smplx)
    model_joints_to_smplx.load_state_dict(torch.load(cfg.model.model_smplx.ckpt, map_location=device))
    model_joints_to_smplx.to(device)
    model_joints_to_smplx.eval()

    model_body = Unet(**cfg.model.synhsi_body)
    model_body.load_state_dict(torch.load(cfg.model.synhsi_body.ckpt, map_location=device))
    model_body.to(device)
    model_body.eval()

    synhsi_dataset = TrumansDataset(**cfg.dataset)
    # NOTE(review): no visible import statement binds ``hydra`` (the hydra
    # import at the top of the file is commented out); unless a star import
    # provides it, this line raises NameError.
    sampler_body = hydra.utils.instantiate(cfg.sampler.pelvis)
    sampler_body.set_dataset_and_model(synhsi_dataset, model_body)
    # The hand sampler is disabled, so action types that index
    # samplers['hand'] receive None.
    samplers = {'body': sampler_body, 'hand': None}

    base_speed = get_base_speed(cfg, trajectory, is_zup=False)
    # Goals are spaced at 0.6 * base_speed frames; never below one frame,
    # because a zero stride is not a valid range() step.
    stride = max(1, int(0.6 * base_speed))
    mat, goal_list, action_label_list, sampler_list = get_guidance(
        cfg, trajectory, samplers, act_type=cfg.action_type, speed=stride)
    points_all = sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list)

    vertices = None
    for i in range(cfg.batch_size):
        keypoint_gene_torch = torch.from_numpy(points_all[i]).reshape(-1, cfg.dataset.nb_joints * 3).to(device)
        pose, transl, left_hand, right_hand, vertices = joints_to_smpl(
            model_joints_to_smplx, keypoint_gene_torch, cfg.dataset.joints_ind, cfg.interp_s)
    return vertices.tolist()
# v = sample()
#
#
# return v
# def load_dataset_meta(cfg):
# metas = np.load(cfg.)
# if __name__ == '__main__':
# OmegaConf.register_resolver("times_three", times_three)
# OmegaConf.register_new_resolver("times", lambda x, y: int(x) * int(y))
# sample()