# MCC_slim/util/co3d_utils.py (provenance: hugoycj, initial commit cacb27a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
from omegaconf import DictConfig
from typing import Optional
import torch
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
JsonIndexDatasetMapProviderV2
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
NDCMultinomialRaysampler,
ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
# Categories reserved for evaluation when holding out is enabled
# (see get_all_dataset_maps); a set literal gives O(1) membership tests.
HOLDOUT_CATEGORIES = {
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
}
def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Build the CO3Dv2 dataset map (train/val/test datasets) for one category.

    Args:
        dataset_root: Root directory of the CO3Dv2 dataset.
        category: CO3D category name to load.
        subset_name: Name of the CO3D subset (e.g. a few-view split).

    Returns:
        The DatasetMap produced by JsonIndexDatasetMapProviderV2.
    """
    # Register the provider's config fields before instantiation.
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    inner_dataset_args = DictConfig(
        {"remove_empty_masks": False, "load_point_clouds": False}
    )
    provider = JsonIndexDatasetMapProviderV2(
        dataset_root=dataset_root,
        category=category,
        subset_name=subset_name,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=inner_dataset_args,
    )
    return provider.get_dataset_map()
def _load_pointcloud(pcl_path, max_points):
    """
    Load a point cloud from disk, optionally subsampling it.

    Args:
        pcl_path: Path to the point-cloud file readable by pytorch3d's IO.
        max_points: Subsample to at most this many points; a non-positive
            value keeps every point.

    Returns:
        The (possibly subsampled) point cloud.
    """
    cloud = IO().load_pointcloud(pcl_path)
    if max_points <= 0:
        return cloud
    return cloud.subsample(max_points)
def get_all_dataset_maps(co3d_path, holdout_categories):
    """
    Load cached per-category dataset maps and split them into train/val.

    Args:
        co3d_path: Root of the CO3D dataset; every entry under it that does
            not end in '.json' is treated as a category directory.
        holdout_categories: If truthy, categories in HOLDOUT_CATEGORIES go
            to validation only and all others to training only; if falsy,
            every category appears in both splits.

    Returns:
        (train_dataset_maps, val_dataset_maps): dicts mapping category name
        to the object loaded from 'dataset_cache/<category>_{train,val}.pt'.
    """
    # glob.glob already returns a list; no need to wrap it.
    all_categories = [
        c.split('/')[-1]
        for c in glob.glob(co3d_path + '/*')
        if not c.endswith('.json')
    ]
    # NOTE(review): sorting by hash() shuffles the category order, but str
    # hashing is salted per process (PYTHONHASHSEED), so this order is NOT
    # reproducible across runs — confirm whether determinism matters here.
    all_categories = sorted(all_categories, key=lambda x: hash(x))

    train_dataset_maps = {}
    val_dataset_maps = {}
    for category in all_categories:
        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt'),
        }
        # Held-out categories are val-only; all others train-only when
        # holding out, or present in both splits when not.
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']
    # Fixed typo in the original log messages ("categores").
    print('Loaded', len(train_dataset_maps), 'categories for train')
    print('Loaded', len(val_dataset_maps), 'categories for val')
    return train_dataset_maps, val_dataset_maps
def get_rgbd_points(
    imh, imw,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> torch.Tensor:
    """
    Unproject a depth map through the camera into a dense grid of 3D points.

    Args:
        imh, imw: Target image height/width; depth (and mask) are bilinearly
            resized to this resolution before unprojection.
        camera: Camera used to cast one ray per pixel.
        depth_map: Depth tensor; the code indexes channel 0, so it is
            presumably shaped (B, 1, H, W) — TODO confirm against callers.
        mask: Optional soft foreground mask, resized like the depth and
            thresholded at mask_thr.
        mask_thr: Threshold applied to the resized mask.

    Returns:
        Tensor of per-pixel 3D points; invalid pixels (non-positive depth,
        or below the mask threshold) have their coordinates set to +inf.
        NOTE(review): the original annotation said Pointclouds, but the code
        returns a plain tensor; annotation corrected accordingly.
    """
    # Resize the depth map to the requested output resolution.
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    # convert the depth maps to point clouds using the grid ray sampler
    # Sample one unit-length ray per pixel, then overwrite the ray lengths
    # with the per-pixel depth so each ray point lands at the observed
    # surface when converted to 3D points.
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]
    # Pixels with non-positive depth carry no geometry.
    pts_mask = depth_map > 0.0
    if mask is not None:
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        # Additionally drop pixels outside the (soft) foreground mask.
        pts_mask *= mask > mask_thr
    # Invalid points are flagged with +inf rather than removed, keeping the
    # output a dense per-pixel grid aligned with the (imh, imw) image.
    pts_3d[~pts_mask] = float('inf')
    return pts_3d.squeeze(0).squeeze(0)