# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
import os
import random

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.ops import sample_points_from_meshes

from util.hypersim_utils import read_h5py, read_img


def hypersim_collate_fn(batch):
    # Each dataset item is (seen_data, gt_data, mesh_points, scene_name).
    assert len(batch[0]) == 4
    return (
        FrameData.collate([x[0] for x in batch]),
        FrameData.collate([x[1] for x in batch]),
        FrameData.collate([x[2] for x in batch]),
        [x[3] for x in batch],  # scene names stay a plain Python list
    )
def is_good_xyz(xyz):
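    """Accept a frame only if more than 2000 pixels have finite XYZ values."""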
assert len(xyz.shape) == 3
return (torch.isfinite(xyz.sum(axis=2))).sum() > 2000


def get_camera_pos_file_name_from_frame_name(frame_name):
    # Map '<scene>/images/scene_cam_XX_final_preview/<frame>' to the scene's
    # '<scene>/_detail/cam_XX/camera_keyframe_positions.hdf5' file.
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_positions.hdf5'
    return '/'.join(tmp)


def get_camera_look_at_file_name_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_look_at_positions.hdf5'
    return '/'.join(tmp)


def get_camera_orientation_file_name_from_frame_name(frame_name):
    tmp = frame_name.split('/')
    tmp[-3] = '_detail'
    tmp[-2] = 'cam_' + tmp[-2].split('_')[2]
    tmp[-1] = 'camera_keyframe_orientations.hdf5'
    return '/'.join(tmp)


def read_scale_from_frame_name(frame_name):
    # Asset-to-meter scale from the scene's '_detail/metadata_scene.csv'.
    # Only the 'meters_per_asset_unit' row holds the value; the first line
    # is a header, so parsing it unconditionally would fail.
    tmp = frame_name.split('/')
    with open('/'.join(tmp[:-3] + ['_detail', 'metadata_scene.csv'])) as f:
        for line in f:
            items = line.split(',')
            if items[0] == 'meters_per_asset_unit':
                return float(items[1])


def random_crop(xyz, img, is_train=True):
    assert xyz.shape[0] == img.shape[0]
    assert xyz.shape[1] == img.shape[1]
    # Take the largest square crop: random position in training,
    # center crop at eval time.
    height, width = img.shape[0], img.shape[1]
    side = min(height, width)
    if is_train:
        i = torch.randint(0, height - side + 1, size=(1,)).item()
        j = torch.randint(0, width - side + 1, size=(1,)).item()
    else:
        i = (height - side) // 2
        j = (width - side) // 2
    xyz = xyz[i:i+side, j:j+side]
    img = img[i:i+side, j:j+side]
    # Resize to the model's input resolutions (HWC -> NCHW for interpolate,
    # then back): 112x112 for the XYZ map, 224x224 for the RGB image.
    xyz = torch.nn.functional.interpolate(
        xyz[None].permute(0, 3, 1, 2), (112, 112),
        mode='bilinear',
        align_corners=False,
    ).permute(0, 2, 3, 1)[0]
    img = torch.nn.functional.interpolate(
        img[None].permute(0, 3, 1, 2), (224, 224),
        mode='bilinear',
        align_corners=False,
    ).permute(0, 2, 3, 1)[0]
    return xyz, img
class HyperSimDataset(torch.utils.data.Dataset):
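    """Hypersim frames paired with precomputed GT point clouds for MCC.

    Each item is (seen_data, gt_data, mesh_points, scene_name); see
    hypersim_collate_fn for how items are batched.
    """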
def __init__(self, args, is_train, is_viz=False, **kwargs):
self.args = args
self.is_train = is_train
self.is_viz = is_viz
self.dataset_split = 'train' if is_train else 'val'
self.scene_names = self.load_scene_names(is_train)
if not is_train:
self.meshes = self.load_meshes()
self.hypersim_gt = self.load_hypersim_gt()
def load_hypersim_gt(self):
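        """Load precomputed GT point clouds, keeping only finite points.

        Points are randomly subsampled (50% train, 10% val) to save memory.
        """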
gt_filename = 'hypersim_gt_train.pt' if self.dataset_split == 'train' else 'hypersim_gt_val.pt'
print('loading GT file from', gt_filename)
gt = torch.load(gt_filename)
for scene_name in gt.keys():
good = torch.isfinite(gt[scene_name][0].sum(axis=1)) & torch.isfinite(gt[scene_name][1].sum(axis=1))
# Subsample GT to reduce memory usage.
if self.is_train:
good = good & (torch.rand(good.shape) < 0.5)
else:
good = good & (torch.rand(good.shape) < 0.1)
gt[scene_name] = [gt[scene_name][0][good], gt[scene_name][1][good]]
return gt
def load_meshes(self):
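        """Load precomputed validation meshes, keyed by scene name."""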
return torch.load('all_hypersim_val_meshes.pt')
def load_scene_names(self, is_train):
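        """Read the scene names for this split from the Hypersim metadata CSV.

        Note the CSV labels the held-out split 'test', which this dataset
        exposes as 'val'.
        """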
split = 'train' if is_train else 'test'
scene_names = []
with open(os.path.join(
self.args.hypersim_path,
'evermotion_dataset/analysis/metadata_images_split_scene_v1.csv'),'r') as f:
for line in f:
items = line.split(',')
if items[-1].strip() == split:
scene_names.append(items[0])
scene_names = sorted(list(set(scene_names)))
print(len(scene_names), 'scenes loaded:', scene_names)
return scene_names
def is_corrupted_frame(self, frame):
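        """Known-corrupted camera trajectories that should be skipped."""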
return (
('ai_003_001' in frame and 'cam_00' in frame)
or ('ai_004_009' in frame and 'cam_01' in frame)
)
def get_hypersim_data(self, index):
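        """Load one sample, retrying on bad or corrupted frames.

        The first 10 attempts stay on the scene selected by `index`; later
        attempts fall back to a random scene so one bad scene cannot stall
        a worker.
        """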
for retry in range(1000):
try:
if retry < 10:
scene_name = self.scene_names[index % len(self.scene_names)]
else:
scene_name = random.choice(self.scene_names)
frames = glob.glob(os.path.join(self.args.hypersim_path, scene_name, 'images/scene_cam_*_final_preview/*tonemap*'))
seen_frame = random.choice(frames)
if self.is_corrupted_frame(seen_frame):
continue
seen_data = self.load_frame_data(seen_frame)
if not is_good_xyz(seen_data[0]):
continue
cur_gt = self.hypersim_gt[scene_name]
gt_data = [cur_gt[0], cur_gt[1]]
if self.is_train:
mesh_points = torch.zeros((1,))
else:
mesh_points = sample_points_from_meshes(self.meshes[scene_name], 1000000)
# get camera positions
camera_positions = read_h5py(get_camera_pos_file_name_from_frame_name(seen_frame))
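                # frame index is the 'NNNN' in '...frame.NNNN.tonemap.jpg'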
camera_position = camera_positions[int(seen_frame.split('.')[-3])]
# get camera orientations
cam_orientations = read_h5py(get_camera_orientation_file_name_from_frame_name(seen_frame))
cam_orientation = cam_orientations[int(seen_frame.split('.')[-3])]
cam_orientation = cam_orientation * (-1.0)
# rotate to camera direction
seen_data[0] = torch.matmul(seen_data[0], cam_orientation)
gt_data[0] = torch.matmul(gt_data[0], cam_orientation)
# shift to camera center
camera_position = torch.matmul(camera_position, cam_orientation)
seen_data[0] -= camera_position
gt_data[0] -= camera_position
# to meter
asset_to_meter_scale = read_scale_from_frame_name(seen_frame)
seen_data[0] = seen_data[0] * asset_to_meter_scale
gt_data[0] = gt_data[0] * asset_to_meter_scale
# get points GT
n_gt = 30000
in_front_of_cam = (gt_data[0][..., 2] > 0)
if in_front_of_cam.sum() < 1000:
print('Warning! Not enough in front of cam.', in_front_of_cam.sum())
continue
gt_data = [gt_data[0][in_front_of_cam], gt_data[1][in_front_of_cam]]
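                # sample with replacement if fewer than n_gt GT points remain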
if in_front_of_cam.sum() < n_gt:
selected = random.choices(range(gt_data[0].shape[0]), k=n_gt)
else:
selected = random.sample(range(gt_data[0].shape[0]), n_gt)
gt_data = [gt_data[0][selected][None], gt_data[1][selected][None], torch.zeros((1,))]
if not self.is_train:
mesh_points = torch.matmul(mesh_points, cam_orientation)
mesh_points -= camera_position * asset_to_meter_scale
in_front_of_cam = (mesh_points[..., 2] > 0)
if in_front_of_cam.sum() < 1000:
print('Warning! Not enough mesh in front of cam.', in_front_of_cam.sum())
continue
mesh_points = mesh_points[in_front_of_cam]
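                    # sample with replacement if fewer than n_gt mesh points remain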
if in_front_of_cam.sum() < n_gt:
selected = random.choices(range(mesh_points.shape[0]), k=n_gt)
else:
selected = random.sample(range(mesh_points.shape[0]), n_gt)
mesh_points = mesh_points[selected][None]
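                # mirror the x axis to match the coordinate convention used downstream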
mesh_points[..., 0] *= -1
seen_data[0][..., 0] *= -1
gt_data[0][..., 0] *= -1
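                # image: HWC -> CHW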
seen_data[1] = seen_data[1].permute(2, 0, 1)
return seen_data, gt_data, mesh_points, scene_name
            except Exception as e:
                print(scene_name, 'loading failed', retry, e)
        # All retries failed; fail loudly instead of implicitly returning None.
        raise RuntimeError('could not load a valid Hypersim sample after 1000 retries')
def __getitem__(self, index):
seen_data, gt_data, mesh_points, scene_name = self.get_hypersim_data(index)
# normalize the data
example_std = get_example_std(seen_data[0])
seen_data[0] = seen_data[0] / example_std
gt_data[0] = gt_data[0] / example_std
mesh_points = mesh_points / example_std
return (
seen_data,
gt_data,
mesh_points,
scene_name,
)
def load_frame_data(self, frame_path):
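        """Read one frame's XYZ map and RGB image, then square-crop both.

        The geometry lives next to the preview image, with 'final_preview/'
        swapped for 'geometry_hdf5/' and a '.position.hdf5' suffix.
        """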
frame_xyz_path = frame_path.replace('final_preview/', 'geometry_hdf5/').replace('.tonemap.jpg', '.position.hdf5')
xyz = read_h5py(frame_xyz_path)
img = read_img(frame_path)
xyz, img = random_crop(
xyz, img,
is_train=self.is_train,
)
return [xyz, img]
def __len__(self) -> int:
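        # Items are sampled per scene, so epoch length is the scene count
        # scaled by the configured multiplier.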
if self.is_train:
return int(len(self.scene_names) * self.args.train_epoch_len_multiplier)
elif self.is_viz:
return len(self.scene_names)
else:
return int(len(self.scene_names) * self.args.eval_epoch_len_multiplier)


def get_example_std(x):
    """Scalar scale per example: mean of the std of the finite XYZ points."""
    x = x.reshape(-1, 3)
    x = x[torch.isfinite(x.sum(dim=1))]
    return x.std(dim=0).mean().detach()