import glob
import os
from typing import Optional

import torch
from omegaconf import DictConfig

from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
    NDCMultinomialRaysampler,
    ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase

HOLDOUT_CATEGORIES = {
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
}

def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Obtain the dataset map that contains the train/val/test dataset objects.
    """
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    dataset_map_provider = JsonIndexDatasetMapProviderV2(
        category=category,
        subset_name=subset_name,
        dataset_root=dataset_root,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=DictConfig(
            {"remove_empty_masks": False, "load_point_clouds": False}
        ),
    )
    return dataset_map_provider.get_dataset_map()
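
# Example usage (a minimal sketch; the dataset root is a placeholder and
# "fewview_dev" is assumed to be one of the CO3Dv2 subset lists you have
# downloaded):
#
#   dataset_map = get_dataset_map("/path/to/co3dv2", "apple", "fewview_dev")
#   train_dataset = dataset_map.train  # DatasetMap has train/val/test fields
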

def _load_pointcloud(pcl_path, max_points):
    # Load a point cloud from disk and optionally subsample it to at most
    # max_points points (a non-positive max_points keeps every point).
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)
    return pcl
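
# Example usage (a minimal sketch; the .ply path is a placeholder):
#
#   pcl = _load_pointcloud("/path/to/pointcloud.ply", max_points=20000)
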

def get_all_dataset_maps(co3d_path, holdout_categories):
    # Each subdirectory of the CO3D root is a category; skip the json files.
    all_categories = [
        os.path.basename(c)
        for c in glob.glob(co3d_path + '/*')
        if not c.endswith('.json')
    ]
    # NOTE: string hashes are salted per process, so this ordering is only
    # reproducible across runs if PYTHONHASHSEED is fixed.
    all_categories = sorted(all_categories, key=lambda x: hash(x))

    train_dataset_maps = {}
    val_dataset_maps = {}
    for category in all_categories:
        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt'),
        }
        # With holdout_categories enabled, the held-out categories are
        # excluded from training and are the only ones used for validation.
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']

    print('Loaded', len(train_dataset_maps), 'categories for train')
    print('Loaded', len(val_dataset_maps), 'categories for val')
    return train_dataset_maps, val_dataset_maps
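
# Example usage (a minimal sketch; assumes the per-category dataset maps were
# previously serialized to dataset_cache/<category>_{train,val}.pt):
#
#   train_maps, val_maps = get_all_dataset_maps("/path/to/co3dv2", True)
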

def get_rgbd_points(
    imh, imw,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> torch.Tensor:
    """
    Given a batch of depth maps, masks and cameras, generate a point map by
    unprojecting the depth maps to world coordinates. Points with
    non-positive depth, or whose mask value falls below mask_thr, are set
    to inf. For a single camera the result has shape (imh, imw, 3).
    """
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )

    # Cast one unit-length ray per pixel, then overwrite the ray lengths with
    # the depth values so the ray points land on the observed surface.
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]

    # Invalidate points with non-positive depth and, if a mask is given,
    # points whose (resized) mask value is below the threshold.
    pts_mask = depth_map > 0.0
    if mask is not None:
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        pts_mask *= mask > mask_thr
    pts_3d[~pts_mask] = float('inf')
    return pts_3d.squeeze(0).squeeze(0)
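
# Example usage (a minimal sketch with synthetic inputs; a real camera and
# depth map would come from the dataset frames loaded above):
#
#   from pytorch3d.renderer import PerspectiveCameras
#
#   camera = PerspectiveCameras()               # identity pose, unit focal
#   depth = torch.rand(1, 1, 64, 64)            # (B, 1, H, W)
#   fg_mask = torch.ones(1, 1, 64, 64)          # (B, 1, H, W)
#   pts = get_rgbd_points(128, 128, camera, depth, fg_mask)  # (128, 128, 3)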