import copy

import numpy as np
import skimage.morphology
import torch
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)

from .render import render
from .ops import project_points, get_pointcloud, merge_pointclouds


def downsample_point_cloud(optimization_bundle, device="cpu"):
    """Rebuild the scene point cloud at a reduced resolution.

    Each non-supporting frame in ``optimization_bundle["frames"]`` is
    downsampled to at most 360 px on its longest side and re-projected to
    world space. The first frame seeds the cloud; later frames only
    contribute points that fall outside the rendered mask of the cloud
    accumulated so far, so geometry is not duplicated.
    """
    point_cloud = None
    for frame in optimization_bundle["frames"]:
        if frame.get("supporting", False):
            continue

        downsampled_image = copy.deepcopy(frame["image"])
        downsampled_image.thumbnail((360, 360))
        w, h = downsampled_image.size  # PIL reports (width, height)

        # regenerate the camera at the lower resolution; PyTorch3D expects
        # image_size as (height, width) and the principal point as (x, y)
        # in screen coordinates
        R, T = look_at_view_transform(
            device=device, azim=frame["azim"], elev=frame["elev"], dist=frame["dist"]
        )
        cameras = PerspectiveCameras(
            R=R,
            T=T,
            focal_length=torch.tensor([w], dtype=torch.float32),
            principal_point=(((w - 1) / 2, (h - 1) / 2),),
            image_size=((h, w),),
            device=device,
            in_ndc=False,
        )

        # downsample the depth map to the thumbnail resolution; nearest-neighbour
        # interpolation avoids blending values across depth discontinuities
        downsampled_depth = torch.nn.functional.interpolate(
            torch.tensor(frame["depth"]).unsqueeze(0).unsqueeze(0).float().to(device),
            size=(h, w),
            mode="nearest",
        ).squeeze()

        # lift the downsampled depth to world-space points
        xy_depth_world = project_points(cameras, downsampled_depth)

        # flatten the thumbnail into per-point RGB features in [0, 1]
        rgb = (
            torch.from_numpy(np.asarray(downsampled_image).copy()).reshape(-1, 3).float() / 255
        ).to(device)

        if point_cloud is None:
            # the first non-supporting frame seeds the point cloud
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
        else:
            images, masks, depths = render(cameras, point_cloud, radius=1e-2)
            # pytorch3d's mask might be slightly too big (subpixel coverage),
            # so we erode it by one pixel to avoid seams
            masks[0] = torch.from_numpy(
                skimage.morphology.binary_erosion(
                    masks[0].cpu().numpy(), footprint=skimage.morphology.disk(1)
                )
            ).to(device)
            # keep only the points this frame adds beyond the existing cloud
            partial_outpainted_point_cloud = get_pointcloud(
                xy_depth_world[0][~masks[0].view(-1)],
                device=device,
                features=rgb[~masks[0].view(-1)],
            )
            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return point_cloud
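

if __name__ == "__main__":
    # Minimal smoke test with synthetic data -- a sketch, not part of the
    # pipeline. The bundle layout below is an assumption inferred from the
    # reads in downsample_point_cloud: each frame carries a PIL image, a
    # numpy depth map, and the look_at_view_transform pose parameters, plus
    # an optional "supporting" flag. Run as a module (e.g.
    # `python -m <package>.<this_module>`) so the relative imports resolve.
    from PIL import Image

    frame = {
        "image": Image.fromarray(np.full((512, 512, 3), 127, dtype=np.uint8)),
        "depth": np.ones((512, 512), dtype=np.float32),
        "azim": 0.0,
        "elev": 0.0,
        "dist": 1.0,
        "supporting": False,
    }
    bundle = {"frames": [frame]}
    cloud = downsample_point_cloud(bundle)
    print(cloud)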