import os

import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset, DataLoader, Subset


class TrumansDataset(Dataset):
    def __init__(self, folder, device, mesh_grid, batch_size=1, seq_len=32, step=1,
                 nb_voxels=32, train=True, load_scene=True, load_action=True,
                 no_objects=False, **kwargs):
        self.device = device
        self.train = train
        self.load_scene = load_scene
        self.load_action = load_action

        # self.body_pose = np.load(os.path.join(folder, 'human_pose.npy'))
        # self.transl = np.load(os.path.join(folder, 'human_transl.npy'))
        # self.global_orient = np.load(os.path.join(folder, 'human_orient.npy'))
        # self.motion_ind = np.load(os.path.join(folder, 'idx_start.npy'))
        # self.joints = np.load(os.path.join(folder, 'human_joints.npy'))
        # self.file_blend = np.load(os.path.join(folder, 'file_blend.npy'))

        self.seq_len = seq_len
        self.step = step
        self.batch_size = batch_size

        # if self.load_action:
        #     self.action_label = np.load(os.path.join(folder, 'action_label.npy')).astype(np.float32)

        if self.load_scene:
            self.mesh_grid = mesh_grid
            self.nb_voxels = nb_voxels
            self.no_objects = no_objects
            self.scene_occ = []
            self.scene_dict = {}
            self.scene_folder = os.path.join(folder, 'Scene')
            # self.scene_flag = np.load(os.path.join(folder, 'scene_flag.npy'))

            if not no_objects:
                # self.object_flag = np.load(os.path.join(folder, 'object_flag.npy'))
                # self.object_mat = np.load(os.path.join(folder, 'object_mat.npy'))
                self.object_occ = {}
                self.object_folder = os.path.join(folder, 'Object')
                for file in sorted(os.listdir(self.object_folder)):
                    print(f"Loading object occupied coordinates {file}")
                    obj_name = file.replace('.npy', '')
                    self.object_occ[obj_name] = torch.from_numpy(
                        np.load(os.path.join(self.object_folder, file))).to(device)

            for sid, file in enumerate(sorted(os.listdir(self.scene_folder))):
                # if scene_name != '' and scene_name not in file:
                #     continue
                print(f"{sid} Loading Scene Mesh {file}")
                scene_occ = np.load(os.path.join(self.scene_folder, file))
                scene_occ = torch.from_numpy(scene_occ).to(device=device, dtype=bool)
                self.scene_occ.append(scene_occ)
                self.scene_dict[file] = sid

            self.scene_occ = torch.stack(self.scene_occ)
            # Scene grid spec: [min_x, min_y, min_z, max_x, max_y, max_z, res_x, res_y, res_z]
            self.scene_grid_np = np.array([-3, 0, -4, 3, 2, 4, 300, 100, 400])
            self.scene_grid_torch = torch.tensor([-3, 0, -4, 3, 2, 4, 300, 100, 400]).to(device)

            # Precomputed per-point batch indices used to index the stacked occupancy grids.
            self.batch_id = torch.linspace(0, batch_size - 1, batch_size).tile((nb_voxels ** 3, 1)).T \
                .reshape(-1, 1).to(device=device, dtype=torch.long)
            self.batch_id_obj = torch.linspace(0, batch_size - 1, batch_size).tile((9000, 1)).T \
                .reshape(-1, 1).to(device=device, dtype=torch.long)

        # TODO CHANGE STEP
        norm = np.load(os.path.join(folder, 'norm.npy'), allow_pickle=True).item()[f'{seq_len, 3}']
        self.min = norm[0].astype(np.float32)
        self.max = norm[1].astype(np.float32)
        self.min_torch = torch.tensor(self.min).to(device)
        self.max_torch = torch.tensor(self.max).to(device)

    def add_object_points(self, points, occ):
        # Rasterize translated object points into the (single-batch) occupancy grid.
        points = points.reshape(-1, 3)
        voxel_size = torch.div(self.scene_grid_torch[3:6] - self.scene_grid_torch[:3],
                               self.scene_grid_torch[6:])
        voxel = torch.div(points - self.scene_grid_torch[:3], voxel_size)
        voxel = voxel.to(dtype=torch.long)
        # voxel = rearrange(voxel, 'b p c -> (b p) c')
        lb = torch.all(voxel >= 0, dim=-1)
        ub = torch.all(voxel < self.scene_grid_torch[6:], dim=-1)
        in_bound = torch.logical_and(lb, ub)
        # voxel = torch.cat([self.batch_id_obj, voxel], dim=-1)
        voxel = voxel[in_bound]
        occ[0, voxel[:, 0], voxel[:, 1], voxel[:, 2]] = True
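    # The world-to-voxel mapping above reappears verbatim in get_occ_for_points
    # below. A minimal sketch of that shared computation as a standalone helper,
    # assuming the 9-element grid spec (min_xyz, max_xyz, resolution_xyz) used
    # throughout; `world_to_voxel` is a hypothetical name and is not called by
    # the original code:
    def world_to_voxel(self, points):
        """Map world-space points of shape (..., 3) to integer voxel indices
        on the scene grid, plus a mask of points that fall inside the grid."""
        voxel_size = torch.div(self.scene_grid_torch[3:6] - self.scene_grid_torch[:3],
                               self.scene_grid_torch[6:])
        voxel = torch.div(points - self.scene_grid_torch[:3], voxel_size).to(dtype=torch.long)
        in_bound = torch.logical_and(torch.all(voxel >= 0, dim=-1),
                                     torch.all(voxel < self.scene_grid_torch[6:], dim=-1))
        return voxel, in_bound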
    def get_occ_for_points(self, points, obj_locs, scene_flag):
        # TODO
        # points_new = points.reshape(-1, 3)
        # center_xz = points_new[:, [0, 2]].mean(axis=0)
        # if torch.norm(center_xz) > 0.:
        #     occ_for_points = torch.load('occ_for_points_at_clear_space.pt').to(points.device)
        #     return occ_for_points

        # A scene name may be passed instead of scene ids; resolve it via scene_dict.
        if isinstance(scene_flag, str):
            for k, v in self.scene_dict.items():
                if scene_flag in k:
                    scene_flag = [v]
                    break

        batch_size = points.shape[0]
        seq_len = points.shape[1]
        points = points.reshape(-1, 3)
        voxel_size = torch.div(self.scene_grid_torch[3:6] - self.scene_grid_torch[:3],
                               self.scene_grid_torch[6:])
        voxel = torch.div(points - self.scene_grid_torch[:3], voxel_size)
        voxel = voxel.to(dtype=torch.long)
        # voxel = rearrange(voxel, 'b p c -> (b p) c')
        lb = torch.all(voxel >= 0, dim=-1)
        ub = torch.all(voxel < self.scene_grid_torch[6:], dim=-1)
        in_bound = torch.logical_and(lb, ub)
        voxel[torch.logical_not(in_bound)] = 0
        voxel = torch.cat([self.batch_id, voxel], dim=1)

        occ = self.scene_occ[scene_flag]
        # TODO
        # occ[:] = False
        # occ[:, :, 0, :] = True
        # import cv2
        # img = occ[0, :, 10, :].detach().cpu().numpy()
        # im = np.zeros((300, 400))
        # im[img] = 255
        # cv2.imwrite('gray.jpg', im.T)

        if obj_locs:
            for obj_name, obj_loc in obj_locs.items():
                obj_points = self.object_occ[obj_name].clone()
                obj_points[:, 0] += obj_loc['x']
                obj_points[:, 2] += obj_loc['z']
                self.add_object_points(obj_points, occ)

        occ_for_points = occ[voxel[:, 0], voxel[:, 1], voxel[:, 2], voxel[:, 3]]
        occ_for_points[torch.logical_not(in_bound)] = True  # out-of-bound points count as occupied
        occ_for_points = occ_for_points.reshape(batch_size, seq_len, -1)
        # torch.save(occ_for_points, 'occ_for_points_at_clear_space.pt')
        # occ_for_points = torch.ones(batch_size, seq_len, 22).to('cuda')
        return occ_for_points

    def create_meshgrid(self, batch_size=1):
        bbox = self.mesh_grid
        size = (self.nb_voxels, self.nb_voxels, self.nb_voxels)
        x = torch.linspace(bbox[0], bbox[1], size[0])
        y = torch.linspace(bbox[2], bbox[3], size[1])
        z = torch.linspace(bbox[4], bbox[5], size[2])
        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
        grid = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)
        grid = grid.repeat(batch_size, 1, 1)
        # aug_z = 0.75 + torch.rand(batch_size, 1) * 0.35
        # grid[:, :, 2] = grid[:, :, 2] * aug_z
        return grid

    @staticmethod
    def combine_mesh(vert_list, face_list):
        # Concatenate meshes, offsetting face indices by the running vertex count.
        assert len(vert_list) == len(face_list)
        verts = None
        faces = None
        for v, f in zip(vert_list, face_list):
            if verts is None:
                verts = v
                faces = f
            else:
                f = f + verts.shape[0]
                verts = torch.cat([verts, v])
                faces = torch.cat([faces, f])
        return verts, faces

    @staticmethod
    def transform_mesh(vert_list, trans_mats):
        # Apply a rigid 4x4 transform to each vertex set.
        assert len(vert_list) == len(trans_mats)
        vert_list_new = []
        for v, m in zip(vert_list, trans_mats):
            v = v @ m[:3, :3].T + m[:3, 3]
            vert_list_new.append(v)
        vert_list_new = torch.stack(vert_list_new)
        return vert_list_new

    def __len__(self):
        # NOTE: relies on self.motion_ind, whose load is commented out in __init__.
        return len(self.motion_ind)

    def normalize(self, data):
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        # data = (data - self.mean) / self.std
        data = -1. + 2. * (data - self.min) / (self.max - self.min)
        data = data.reshape(shape_orig)
        return data

    def normalize_torch(self, data):
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        # data = (data - self.mean) / self.std
        data = -1. + 2. * (data - self.min_torch) / (self.max_torch - self.min_torch)
        data = data.reshape(shape_orig)
        return data

    def denormalize(self, data):
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        # data = data * self.std + self.mean
        data = (data + 1.) * (self.max - self.min) / 2. + self.min
        data = data.reshape(shape_orig)
        return data
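    # Round-trip sketch for the min-max normalization above (illustrative only;
    # `ds` is a constructed TrumansDataset and `x` any array whose trailing
    # elements group into xyz triples):
    #   x_norm = ds.normalize(x)         # each xyz channel mapped into [-1, 1]
    #   x_back = ds.denormalize(x_norm)  # recovers x up to float rounding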
    def denormalize_torch(self, data):
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        # data = data * self.std + self.mean
        data = (data + 1.) * (self.max_torch - self.min_torch) / 2. + self.min_torch
        data = data.reshape(shape_orig)
        return data
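
# A minimal usage sketch, assuming a local data layout matching the loads in
# __init__ above; the folder path and mesh_grid bounding box here are
# hypothetical, and scene_flag is one scene id per batch element:
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    mesh_grid = [-1, 1, -1, 1, -1, 1]  # (x_min, x_max, y_min, y_max, z_min, z_max)
    dataset = TrumansDataset('data/trumans', device, mesh_grid,
                             batch_size=4, seq_len=32, nb_voxels=32)
    grid = dataset.create_meshgrid(batch_size=4)  # (4, 32**3, 3) query points
    occ = dataset.get_occ_for_points(grid.to(device), obj_locs=None, scene_flag=[0] * 4)
    print(occ.shape)  # per-point occupancy, shape (4, 32**3, 1)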