# NOTE: extraction artifact removed — this file was scraped with a
# Hugging Face "Spaces: Running on Zero" badge prepended to the code.
import copy
import torch
import numpy as np
import skimage
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)
from .render import render
from .ops import project_points, get_pointcloud, merge_pointclouds
def downsample_point_cloud(optimization_bundle, device="cpu"):
    """Rebuild the scene point cloud at a reduced (<= 360 px) resolution.

    Walks the non-supporting frames of the bundle, unprojects each frame's
    depth map at thumbnail resolution, and accumulates the results. The first
    usable frame contributes every point; each later frame contributes only
    the points that fall outside the mask obtained by rendering the cloud
    accumulated so far from that frame's camera, so already-covered geometry
    is not duplicated.

    Args:
        optimization_bundle: dict with a "frames" list. Each frame is a dict
            holding "image" (a PIL image), "depth" (an H x W depth map),
            pose angles "azim" / "elev" / "dist", and optionally a
            "supporting" bool marking auxiliary frames to skip.
        device: torch device string for the intermediate tensors.

    Returns:
        The merged point cloud, or None if every frame was marked supporting.
    """
    point_cloud = None
    for frame in optimization_bundle["frames"]:
        # Supporting frames are auxiliary views; they contribute no points.
        if frame.get("supporting", False):
            continue
        # Work on a copy: PIL's thumbnail() resizes in place.
        downsampled_image = copy.deepcopy(frame["image"])
        downsampled_image.thumbnail((360, 360))
        image_size = downsampled_image.size
        w, h = image_size  # PIL size is (width, height)

        # Regenerate the camera to match the downsampled resolution.
        R, T = look_at_view_transform(
            device=device, azim=frame["azim"], elev=frame["elev"], dist=frame["dist"]
        )
        cameras = PerspectiveCameras(
            R=R,
            T=T,
            focal_length=torch.tensor([w], dtype=torch.float32),
            # NOTE(review): the ((h-1)/2, (w-1)/2) ordering is kept from the
            # original code — confirm it matches the PyTorch3D screen-space
            # (px, py) convention used by the rest of the project.
            principal_point=(((h - 1) / 2, (w - 1) / 2),),
            image_size=(image_size,),
            device=device,
            in_ndc=False,
        )

        # Resample the stored depth map to the thumbnail resolution.
        downsampled_depth = torch.nn.functional.interpolate(
            torch.tensor(frame["depth"]).unsqueeze(0).unsqueeze(0).float().to(device),
            size=(h, w),
            mode="nearest",
        ).squeeze()
        xy_depth_world = project_points(cameras, downsampled_depth)
        rgb = (
            torch.from_numpy(np.asarray(downsampled_image).copy()).reshape(-1, 3).float()
            / 255
        ).to(device)

        if point_cloud is None:
            # First usable frame: keep every unprojected point.
            # (Guarding on `point_cloud is None` rather than `i == 0` — if
            # frame 0 is a supporting frame it is skipped above, and the old
            # `i == 0` test would have passed None into render() below.)
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
        else:
            # Render the accumulated cloud from this frame's camera; only the
            # pixels NOT covered by the render contribute new points.
            images, masks, depths = render(cameras, point_cloud, radius=1e-2)
            # PyTorch3D's mask can overshoot by subpixels, so erode it
            # slightly to avoid seams at the boundary.
            # NOTE(review): the original comment said "we use 2 to be safe"
            # but the footprint is disk(1); behavior kept as-is — confirm
            # which was intended.
            masks[0] = torch.from_numpy(
                skimage.morphology.binary_erosion(
                    masks[0].cpu().numpy(), footprint=skimage.morphology.disk(1)
                )
            ).to(device)
            uncovered = ~masks[0].view(-1)
            partial_outpainted_point_cloud = get_pointcloud(
                xy_depth_world[0][uncovered], device=device, features=rgb[uncovered]
            )
            point_cloud = merge_pointclouds(
                [point_cloud, partial_outpainted_point_cloud]
            )
    return point_cloud