import os
import pickle as pkl

import hydra
import numpy as np
import torch
import yaml
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset, DataLoader

from datasets.trumans import TrumansDataset
from models.joints_to_smplx import joints_to_smpl, JointsToSMPLX
from models.synhsi import Unet
from utils import *
from constants import *


# Action types handled by get_guidance ('write' and 'pure_inter' are supported there as well).
ACT_TYPE = ['scene', 'grasp', 'artic', 'none']


def convert_trajectory(trajectory):
    """Convert a list of {'x', 'y', 'z'} waypoint dicts into an (N, 3) numpy array."""
    trajectory_new = [[t['x'], t['y'], t['z']] for t in trajectory]
    trajectory_new = np.array(trajectory_new)
    return trajectory_new
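

# Example (hypothetical waypoints, matching what the demo front end sends):
#   convert_trajectory([{'x': 0.0, 'y': 0.0, 'z': 0.0}, {'x': 1.0, 'y': 0.0, 'z': 2.0}])
#   -> array([[0., 0., 0.],
#             [1., 0., 2.]])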


def get_base_speed(cfg, trajectory, is_zup=True):
    """Estimate the trajectory's frame density: frames per meter of ground-plane path length."""
    trajectory_layer = trajectory
    if is_zup:
        trajectory_layer = zup_to_yup(trajectory_layer)
    trajectory2D = trajectory_layer[:, [0, 2]]
    distance = np.sum(np.linalg.norm(trajectory2D[1:] - trajectory2D[:-1], axis=1))
    # Floor division gives frames per meter; later scaled and used as the midpoint stride.
    speed = trajectory_layer.shape[0] // distance
    print('Base Speed:', speed, flush=True)
    return speed
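

# Worked example (assumed numbers): a 120-frame trajectory whose ground-plane path
# is 3.4 m long gives speed = 120 // 3.4 = 35.0 frames per meter; used as the
# midpoint stride in get_guidance, this places roughly one goal per meter traveled.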


def get_guidance(cfg, trajectory, samplers, act_type='none', speed=35):
    """Build per-step goals, action labels and samplers from a waypoint trajectory."""
    trajectory_layer = trajectory
    print(trajectory_layer.shape)
    if cfg.action_type != 'pure_inter':
        # Pre-roll frames repeat the start point, then one midpoint every `speed`
        # frames along the trajectory, then the end point repeated for the action.
        midpoints = trajectory_layer[[0] * cfg.len_pre
                                     + list(range(0, len(trajectory_layer), speed))
                                     + [-1] * (cfg.len_act + (1 if cfg.stay_and_act else 0))]
    else:
        # Pure interaction: the character stays at the start point throughout.
        midpoints = trajectory_layer[[0] * cfg.len_pre + [0]
                                     + [0] * (cfg.len_act + (1 if cfg.stay_and_act else 0))]
    midpoints = torch.tensor(midpoints).float().to(cfg.device)
    max_step = midpoints.shape[0] - 1
    mat_init = cfg.batch_size * [np.eye(4)]
    mat_init = torch.from_numpy(np.stack(mat_init, axis=0)).float().to(cfg.device)
    print(midpoints)
    # Place the character at the first midpoint, facing towards the first goal.
    mat_init[:, 0, 3] = midpoints[0, 0]
    mat_init[:, 2, 3] = midpoints[0, 2]
    dx = midpoints[cfg.len_pre + 1, 0] - midpoints[0, 0]
    dz = midpoints[cfg.len_pre + 1, 2] - midpoints[0, 2]
    print(np.arctan2(dx.item(), dz.item()), dx, dz)
    mat_rot_y = R.from_rotvec(np.array([0, np.arctan2(dx.item(), dz.item()), 0])).as_matrix()
    mat_init[:, :3, :3] = torch.from_numpy(mat_rot_y).float().to(cfg.device)
    goal_list = []
    action_label_list = []
    sampler_list = []
    if act_type == 'none':
        # Plain locomotion: one goal point per step, no active action labels.
        for s in range(max_step):
            goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
                action_label_list.append(action_label)
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['body'])
    elif act_type == 'write':
        # Writing: the hand sampler follows the subsampled trajectory,
        # 16 goal points per step.
        midpoints = torch.from_numpy(trajectory_layer[::4]).float().to(cfg.device)
        for s in range(midpoints.shape[0] // 16):
            goal = torch.zeros((cfg.batch_size, 16, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s * 16: (s + 1) * 16]
            goal_list.append(goal)
            if cfg.dataset.nb_actions > 0:
                action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
                action_label_list.append(action_label)
            else:
                action_label_list.append(None)
            sampler_list.append(samplers['hand'])
    elif act_type == 'scene':
        # Scene interaction: walk to the goal, then switch on the action label
        # for the final len_act steps.
        for s in range(max_step):
            goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
            goal[:, :] = midpoints[s + 1]
            goal_list.append(goal)
            sampler_list.append(samplers['body'])
            action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
            if s > max_step - cfg.len_act:
                action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'grasp':
        for s in range(max_step):
            if s != max_step - cfg.len_act:
                goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
                goal[:, :] = midpoints[s + 1]
                goal_list.append(goal)
            else:
                # Grasping step: the goal is the object location (this branch expects
                # the incoming trajectory to carry an 'Object' entry).
                grasp_goal = zup_to_yup(np.array(trajectory['Object'])).reshape((cfg.batch_size, 1, 3))
                goal = torch.zeros((cfg.batch_size, 3, 3)).float().to(cfg.device)
                goal[:, :] = torch.from_numpy(grasp_goal).float().to(cfg.device)
                goal_list.append(goal)
            action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
            # The hand sampler takes over exactly at the grasping step.
            if s < max_step - cfg.len_act:
                sampler_list.append(samplers['body'])
            elif s == max_step - cfg.len_act:
                sampler_list.append(samplers['hand'])
            else:
                sampler_list.append(samplers['body'])
            if cfg.action_id != -1:
                action_label[:, :, cfg.action_id] = 1.
            action_label_list.append(action_label)
    elif act_type == 'pure_inter':
        # Pure interaction: stay at the start point; the action label is switched
        # on (with weight 2) after the pre-roll.
        sampler_list += [samplers['body']] * (cfg.len_pre + cfg.len_act)
        goal = torch.zeros((cfg.batch_size, 1, 3)).float().to(cfg.device)
        goal[:, :] = midpoints[0]
        goal_list += [goal] * (cfg.len_pre + cfg.len_act)
        action_label = torch.zeros((cfg.batch_size, cfg.dataset.seq_len, cfg.dataset.nb_actions)).float().to(cfg.device)
        action_label_list += [action_label.clone()] * cfg.len_pre
        action_label[:, :, cfg.action_id] = 2.
        action_label_list += [action_label.clone()] * cfg.len_act
    return mat_init, goal_list, action_label_list, sampler_list
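

# What get_guidance returns (shapes follow this file's conventions):
#   mat_init             : (batch_size, 4, 4) initial world transform of the character
#   goal_list[s]         : (batch_size, n_goals, 3) goal keypoints for sampling step s
#   action_label_list[s] : (batch_size, seq_len, nb_actions) action labels, or None
#   sampler_list[s]      : the sampler ('body' or 'hand') driving step s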


def sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list):
    """Autoregressively sample motion windows, chaining them via fixed frames."""
    max_step = len(goal_list)
    fixed_points = None
    fixed_frame = 2
    points_all = []
    cnt_fixed_frame = 0
    cnt_seq_len = 0
    for s in range(max_step):
        print('step', s)
        sampler = sampler_list[s]
        if s != 0:
            # Condition on the last fixed_frame frames of the previous window,
            # expressed in the new local frame.
            fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))
        elif cfg.continue_last:
            # Resume from the transform and fixed points saved by the previous run.
            method_id = cfg.method_name.split('_')[-1]
            method_name_last = cfg.method_name[:-1] + str(int(method_id) - 1)
            mat = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_mat.npy'))).to(sampler.device)
            fixed_points = torch.from_numpy(np.load(os.path.join(cfg.exp_dir, f'{method_name_last}_fixed_points.npy'))).to(sampler.device)
            fixed_points = sampler.dataset.normalize_torch(transform_points(fixed_points, torch.inverse(mat)))
        samples, occs = sampler.p_sample_loop(fixed_points, obj_locs, mat, cfg.scene_name, goal_list[s], action_label_list[s])
        if 0 <= s < cfg.len_pre:
            cnt_fixed_frame += sampler.fixed_frame
            cnt_seq_len += cfg.dataset.seq_len
        points_gene = samples[-1]
        points_gene_np = points_gene.reshape(cfg.batch_size, cfg.dataset.seq_len, -1, 3).cpu().numpy()
        # TODO: unify the two slicing cases below.
        if s == 0 or fixed_frame == 0:
            points_all.append(points_gene_np[:, fixed_frame - 1:])
        elif fixed_frame > 0:
            points_all.append(points_gene_np[:, fixed_frame:])
        # Number of conditioning frames the next window expects.
        fixed_frame = sampler_list[s].fixed_frame if s == max_step - 1 else sampler_list[s + 1].fixed_frame
        # Estimate the rigid transform that re-centres the next window on the
        # pelvis (first 3 keypoints) of the first conditioning frame.
        pelvis_new = points_gene[:, -fixed_frame, :9].cpu().numpy().reshape(cfg.batch_size, 3, 3)
        trans_mats = np.repeat(np.eye(4)[np.newaxis, :, :], cfg.batch_size, axis=0)
        for ip, pn in enumerate(pelvis_new):
            _, ret_R, ret_t = rigid_transform_3D(np.matrix(pn), rest_pelvis, False)
            ret_t[1] = 0.0  # keep the character on the ground plane
            rot_euler = R.from_matrix(ret_R).as_euler('zxy')
            shift_euler = np.array([0, 0, rot_euler[2]])  # keep only the yaw component
            shift_rot_matrix2 = R.from_euler('zxy', shift_euler).as_matrix()
            trans_mats[ip, :3, :3] = shift_rot_matrix2
            trans_mats[ip, :3, 3] = ret_t.reshape(-1)
        mat = torch.from_numpy(trans_mats).to(device=cfg.device, dtype=torch.float32)
        if fixed_frame > 0:
            fixed_points = points_gene[:, -fixed_frame:]
        if s == max_step - 1:
            print('Saved Mat and Fixed Points', flush=True)
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_mat.npy'), mat.cpu().numpy())
            # np.save(os.path.join(cfg.exp_dir, f'{cfg.method_name}_fixed_points.npy'), fixed_points.cpu().numpy())
    points_all = np.concatenate(points_all, axis=1)
    # Drop the pre-roll windows.
    points_all = points_all[:, cnt_seq_len - cnt_fixed_frame:]
    return points_all
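

# Shape bookkeeping (illustrative numbers, assuming seq_len=32, fixed_frame=2,
# len_pre=1 and 5 steps): step 0 contributes 31 frames, every later step
# seq_len - fixed_frame = 30, and the final slice drops the 30-frame pre-roll,
# leaving a (batch_size, 121, nb_joints, 3) keypoint array.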


def sample_wrapper(trajectory, obj_locs):
    """End-to-end demo entry point: trajectory (+ object locations) -> SMPL-X vertices."""
    trajectory = convert_trajectory(trajectory)
    with open('config/config_sample_synhsi.yaml') as f:
        cfg = yaml.safe_load(f)
    cfg = dotDict(cfg)
    print(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keypoints-to-SMPL-X regressor.
    model_joints_to_smplx = JointsToSMPLX(**cfg.model.model_smplx)
    model_joints_to_smplx.load_state_dict(torch.load(cfg.model.model_smplx.ckpt))
    model_joints_to_smplx.to(device)
    model_joints_to_smplx.eval()
    # Body-motion diffusion model.
    model_body = Unet(**cfg.model.synhsi_body)
    model_body.load_state_dict(torch.load(cfg.model.synhsi_body.ckpt))
    model_body.to(device)
    model_body.eval()
    synhsi_dataset = TrumansDataset(**cfg.dataset)
    sampler_body = hydra.utils.instantiate(cfg.sampler.pelvis)
    sampler_body.set_dataset_and_model(synhsi_dataset, model_body)
    # The hand model/sampler is currently disabled, so the 'write' and 'grasp'
    # guidance paths cannot be used with this configuration.
    samplers = {'body': sampler_body, 'hand': None}
    base_speed = get_base_speed(cfg, trajectory, is_zup=False)
    mat, goal_list, action_label_list, sampler_list = get_guidance(cfg, trajectory, samplers,
                                                                   act_type=cfg.action_type,
                                                                   speed=int(0.6 * base_speed))
    points_all = sample_step(cfg, mat, obj_locs, goal_list, action_label_list, sampler_list)
    vertices = None
    # Note: only the vertices of the last batch element are returned.
    for i in range(cfg.batch_size):
        keypoint_gene_torch = torch.from_numpy(points_all[i]).reshape(-1, cfg.dataset.nb_joints * 3).to(device)
        pose, transl, left_hand, right_hand, vertices = joints_to_smpl(model_joints_to_smplx, keypoint_gene_torch,
                                                                       cfg.dataset.joints_ind, cfg.interp_s)
    return vertices.tolist()
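

# Minimal usage sketch (hypothetical waypoints and an empty object dict; the real
# demo passes these in from the front end):
if __name__ == '__main__':
    demo_trajectory = [{'x': 0.0, 'y': 0.0, 'z': 0.0},
                       {'x': 0.5, 'y': 0.0, 'z': 1.0},
                       {'x': 1.0, 'y': 0.0, 'z': 2.0}]
    vertices = sample_wrapper(demo_trajectory, obj_locs={})
    print(len(vertices), 'vertex frames generated')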