# Spaces: Paused  (apparently hosting-page status text captured with this file; not part of the program)
import os

import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset, DataLoader, Subset
class TrumansDataset(Dataset):
    """Scene/object occupancy lookup and coordinate normalisation helpers.

    On construction the dataset loads, from ``folder``:

    * ``Scene/*.npy``  - one occupancy grid per scene, stacked into
      ``self.scene_occ`` (bool tensor); ``self.scene_dict`` maps each file
      name to its index in that stack.
    * ``Object/*.npy`` - per-object point arrays kept in ``self.object_occ``
      (skipped when ``no_objects`` is true).
    * ``norm.npy``     - a pickled dict whose entry ``f'{seq_len, 3}'`` holds
      the (min, max) pair used by the ``normalize*``/``denormalize*`` methods.

    NOTE: the per-frame motion arrays (``human_pose.npy``,
    ``human_transl.npy``, ``human_orient.npy``, ``idx_start.npy``,
    ``human_joints.npy``, ``file_blend.npy``, ``action_label.npy``,
    ``scene_flag.npy``, ``object_flag.npy``, ``object_mat.npy``) are NOT
    loaded by this version of the class, so ``self.motion_ind`` does not
    exist unless a caller assigns it.
    """

    def __init__(self, folder, device, mesh_grid, batch_size=1, seq_len=32, step=1, nb_voxels=32, train=True,
                 load_scene=True, load_action=True, no_objects=False, **kwargs):
        """Load scene/object occupancy data and normalisation bounds.

        Args:
            folder: Root directory containing ``Scene/``, ``Object/`` and
                ``norm.npy``.
            device: Torch device every loaded tensor is moved to.
            mesh_grid: Six values ``(x0, x1, y0, y1, z0, z1)`` read by
                ``create_meshgrid`` as the sampling bounding box.
            batch_size: Number of batch items ``self.batch_id`` is built for.
            seq_len: Used to select the entry of ``norm.npy``.
            step: Stored on the instance; not otherwise used here.
            nb_voxels: Samples per axis of the local grid; also fixes the
                number of rows per batch item in ``self.batch_id``.
            train: Stored on the instance; not otherwise used here.
            load_scene: When false, no scene/object data is loaded and the
                occupancy methods cannot be used.
            load_action: Stored on the instance; not otherwise used here.
            no_objects: When true, ``Object/`` is not read and
                ``self.object_occ`` is not created.
            **kwargs: Accepted and ignored.
        """
        self.device = device
        self.train = train
        self.load_scene = load_scene
        self.load_action = load_action
        self.seq_len = seq_len
        self.step = step
        self.batch_size = batch_size

        if self.load_scene:
            self.mesh_grid = mesh_grid
            self.nb_voxels = nb_voxels
            self.no_objects = no_objects
            self.scene_occ = []
            self.scene_dict = {}
            self.scene_folder = os.path.join(folder, 'Scene')

            if not no_objects:
                self.object_occ = {}
                self.object_folder = os.path.join(folder, 'Object')
                for file_name in sorted(os.listdir(self.object_folder)):
                    print(f"Loading object occupied coordinates {file_name}")
                    obj_name = file_name.replace('.npy', '')
                    obj_points = np.load(os.path.join(self.object_folder, file_name))
                    self.object_occ[obj_name] = torch.from_numpy(obj_points).to(device)

            # Sorted order makes the scene index assigned to each file stable.
            for sid, file_name in enumerate(sorted(os.listdir(self.scene_folder))):
                print(f"{sid} Loading Scene Mesh {file_name}")
                scene_occ = np.load(os.path.join(self.scene_folder, file_name))
                scene_occ = torch.from_numpy(scene_occ).to(device=device, dtype=bool)
                self.scene_occ.append(scene_occ)
                self.scene_dict[file_name] = sid
            self.scene_occ = torch.stack(self.scene_occ)

            # Layout: [min_x, min_y, min_z, max_x, max_y, max_z, n_x, n_y, n_z].
            # (Units are presumably metres - TODO confirm.)
            self.scene_grid_np = np.array([-3, 0, -4, 3, 2, 4, 300, 100, 400])
            self.scene_grid_torch = torch.tensor([-3, 0, -4, 3, 2, 4, 300, 100, 400]).to(device)

            # Column vectors [0,..,0, 1,..,1, ...] giving the batch item of
            # every flattened point. Integer arange avoids the float
            # round-trip of building the same values with linspace.
            self.batch_id = torch.arange(batch_size).repeat_interleave(nb_voxels ** 3) \
                .reshape(-1, 1).to(device=device, dtype=torch.long)
            # 9000 rows per batch item; currently not read by any method in
            # this class (presumably a per-object point budget - TODO confirm).
            self.batch_id_obj = torch.arange(batch_size).repeat_interleave(9000) \
                .reshape(-1, 1).to(device=device, dtype=torch.long)

        # TODO CHANGE STEP
        # NOTE(security): allow_pickle=True executes pickle on load; only use
        # with a trusted norm.npy.
        norm = np.load(os.path.join(folder, 'norm.npy'), allow_pickle=True).item()[f'{seq_len, 3}']
        self.min = norm[0].astype(np.float32)
        self.max = norm[1].astype(np.float32)
        self.min_torch = torch.tensor(self.min).to(device)
        self.max_torch = torch.tensor(self.max).to(device)

    def _voxelize(self, points):
        """Convert ``(N, 3)`` coordinates to scene-grid voxel indices.

        Returns:
            ``(voxel, in_bound)`` where ``voxel`` is an ``(N, 3)`` long
            tensor and ``in_bound`` an ``(N,)`` bool mask that is true where
            all three indices fall inside the scene grid.
        """
        grid = self.scene_grid_torch
        origin, upper, resolution = grid[:3], grid[3:6], grid[6:]
        voxel_size = torch.div(upper - origin, resolution)
        # Floor before the integer cast: a plain cast truncates toward zero,
        # which would map points just below the origin onto voxel 0 and
        # wrongly report them as inside the grid.
        voxel = torch.floor(torch.div(points - origin, voxel_size)).to(dtype=torch.long)
        in_bound = torch.logical_and(
            torch.all(voxel >= 0, dim=-1),
            torch.all(voxel < resolution, dim=-1),
        )
        return voxel, in_bound

    def _scene_index(self, scene_name):
        """Return the index of the first loaded scene whose file name contains ``scene_name``.

        Raises:
            ValueError: If no loaded scene file name contains ``scene_name``.
        """
        for file_name, sid in self.scene_dict.items():
            if scene_name in file_name:
                return sid
        raise ValueError(
            f"No loaded scene matches {scene_name!r}; available: {sorted(self.scene_dict)}"
        )

    def add_object_points(self, points, occ):
        """Mark the voxels hit by ``points`` as occupied in ``occ[0]`` (in place).

        Args:
            points: Tensor reshaped to ``(-1, 3)`` coordinates.
            occ: Occupancy tensor indexed as ``occ[0, x, y, z]``; only the
                first entry along dim 0 is modified. Points outside the
                scene grid are ignored.
        """
        voxel, in_bound = self._voxelize(points.reshape(-1, 3))
        voxel = voxel[in_bound]
        occ[0, voxel[:, 0], voxel[:, 1], voxel[:, 2]] = True

    def get_occ_for_points(self, points, obj_locs, scene_flag):
        """Look up scene occupancy at each query point.

        Args:
            points: Tensor whose first two dims are ``(batch, seq_len)`` and
                whose last dim is 3. After flattening it must have as many
                rows as ``self.batch_id`` (``batch_size * nb_voxels ** 3``
                as built in ``__init__``), because the two are concatenated.
            obj_locs: Optional mapping ``{object_name: {'x': ..., 'z': ...}}``;
                each named object's points are shifted by that offset in
                x/z and added to the occupancy before the lookup. Requires
                ``self.object_occ`` (i.e. ``no_objects=False``).
            scene_flag: Either a scene name (matched as a substring of the
                loaded scene file names) or an index usable on
                ``self.scene_occ`` that keeps the leading scene dimension
                (e.g. a list of indices).

        Returns:
            Bool tensor of shape ``(batch, seq_len, -1)``; true where the
            point lies in an occupied voxel or outside the scene grid.

        Raises:
            ValueError: If ``scene_flag`` is a string matching no loaded scene.
        """
        if isinstance(scene_flag, str):
            scene_flag = [self._scene_index(scene_flag)]

        batch_size, seq_len = points.shape[0], points.shape[1]
        voxel, in_bound = self._voxelize(points.reshape(-1, 3))
        out_of_bound = torch.logical_not(in_bound)

        # Out-of-grid points get a valid placeholder index so the gather
        # below cannot fail; their result is overwritten afterwards.
        voxel[out_of_bound] = 0
        voxel = torch.cat([self.batch_id, voxel], dim=1)

        # With a list (or index tensor) this indexing yields a copy, so the
        # object voxels added below do not leak into self.scene_occ.
        occ = self.scene_occ[scene_flag]

        if obj_locs:
            for obj_name, obj_loc in obj_locs.items():
                obj_points = self.object_occ[obj_name].clone()
                obj_points[:, 0] += obj_loc['x']
                obj_points[:, 2] += obj_loc['z']
                self.add_object_points(obj_points, occ)

        occ_for_points = occ[voxel[:, 0], voxel[:, 1], voxel[:, 2], voxel[:, 3]]
        # Everything outside the scene grid is treated as occupied.
        occ_for_points[out_of_bound] = True
        return occ_for_points.reshape(batch_size, seq_len, -1)

    def create_meshgrid(self, batch_size=1):
        """Build a regular ``nb_voxels``-per-axis grid of sample points.

        The grid spans ``self.mesh_grid`` read as ``(x0, x1, y0, y1, z0, z1)``.

        Returns:
            Float tensor of shape ``(batch_size, nb_voxels ** 3, 3)``; every
            batch entry holds the same points.
        """
        bbox = self.mesh_grid
        size = (self.nb_voxels, self.nb_voxels, self.nb_voxels)
        x = torch.linspace(bbox[0], bbox[1], size[0])
        y = torch.linspace(bbox[2], bbox[3], size[1])
        z = torch.linspace(bbox[4], bbox[5], size[2])
        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
        grid = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)
        return grid.repeat(batch_size, 1, 1)

    @staticmethod
    def combine_mesh(vert_list, face_list):
        """Concatenate several meshes into one vertex/face pair.

        Face indices of each mesh are shifted by the number of vertices
        that precede it so they keep pointing at the right vertices.

        Returns:
            ``(verts, faces)`` tensors, or ``(None, None)`` for empty input.
        """
        assert len(vert_list) == len(face_list)
        if len(vert_list) == 0:
            return None, None
        shifted_faces = []
        vertex_offset = 0
        for v, f in zip(vert_list, face_list):
            shifted_faces.append(f + vertex_offset if vertex_offset else f)
            vertex_offset += v.shape[0]
        # One concatenation instead of re-copying the accumulated result on
        # every iteration.
        return torch.cat(list(vert_list)), torch.cat(shifted_faces)

    @staticmethod
    def transform_mesh(vert_list, trans_mats):
        """Apply one transform matrix to each vertex set and stack the results.

        Each matrix ``m`` is applied as ``v @ m[:3, :3].T + m[:3, 3]``.
        All vertex sets must share a shape for the final ``torch.stack``.
        """
        assert len(vert_list) == len(trans_mats)
        return torch.stack([v @ m[:3, :3].T + m[:3, 3] for v, m in zip(vert_list, trans_mats)])

    def __len__(self):
        """Return the number of motion start indices, or 0 if none are loaded.

        ``motion_ind`` is not loaded by ``__init__`` in this version, so the
        attribute only exists when a caller has assigned it.
        """
        motion_ind = getattr(self, 'motion_ind', None)
        return 0 if motion_ind is None else len(motion_ind)

    def normalize(self, data):
        """Scale array ``data`` (last dim 3) from ``[self.min, self.max]`` to ``[-1, 1]``."""
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        data = -1. + 2. * (data - self.min) / (self.max - self.min)
        return data.reshape(shape_orig)

    def normalize_torch(self, data):
        """Tensor version of ``normalize`` using ``self.min_torch`` / ``self.max_torch``."""
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        data = -1. + 2. * (data - self.min_torch) / (self.max_torch - self.min_torch)
        return data.reshape(shape_orig)

    def denormalize(self, data):
        """Inverse of ``normalize``: map ``[-1, 1]`` back to ``[self.min, self.max]``."""
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        data = (data + 1.) * (self.max - self.min) / 2. + self.min
        return data.reshape(shape_orig)

    def denormalize_torch(self, data):
        """Inverse of ``normalize_torch``: map ``[-1, 1]`` back to the original range."""
        shape_orig = data.shape
        data = data.reshape((-1, 3))
        data = (data + 1.) * (self.max_torch - self.min_torch) / 2. + self.min_torch
        return data.reshape(shape_orig)