import copy
import torch
import numpy as np

import skimage.morphology
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)

from .render import render
from .ops import project_points, get_pointcloud, merge_pointclouds

def downsample_point_cloud(optimization_bundle, device="cpu"):
    """Rebuild the scene point cloud at a reduced resolution.

    Every non-supporting frame is shrunk to at most 360x360, unprojected into
    world space with its stored camera pose, and only the points not already
    covered by the accumulated cloud are merged in.
    """
    point_cloud = None

    for frame in optimization_bundle["frames"]:
        # supporting frames are skipped; only primary frames contribute points
        if frame.get("supporting", False):
            continue

        # shrink the frame to at most 360x360 (thumbnail preserves aspect ratio)
        downsampled_image = copy.deepcopy(frame["image"])
        downsampled_image.thumbnail((360, 360))

        # PIL reports size as (width, height)
        w, h = downsampled_image.size

        # regenerate the point cloud at a lower resolution
        R, T = look_at_view_transform(
            device=device, azim=frame["azim"], elev=frame["elev"], dist=frame["dist"]
        )
        # PyTorch3D expects image_size as (height, width) and the principal
        # point as (x, y) in screen coordinates when in_ndc=False
        cameras = PerspectiveCameras(
            R=R,
            T=T,
            focal_length=torch.tensor([w], dtype=torch.float32),
            principal_point=(((w - 1) / 2, (h - 1) / 2),),
            image_size=((h, w),),
            device=device,
            in_ndc=False,
        )

        # downsample the depth map to match the reduced image resolution;
        # nearest-neighbor avoids blending across depth discontinuities
        downsampled_depth = torch.nn.functional.interpolate(
            torch.tensor(frame["depth"]).unsqueeze(0).unsqueeze(0).float().to(device),
            size=(h, w),
            mode="nearest",
        ).squeeze()

        # unproject the depth map into world-space points
        xy_depth_world = project_points(cameras, downsampled_depth)

        # per-point RGB features in [0, 1], flattened in the same pixel order
        rgb = (torch.from_numpy(np.asarray(downsampled_image).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            # first usable frame: initialize the point cloud directly
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)

        else:
            # render the current cloud from this view to find already-covered pixels
            _, masks, _ = render(cameras, point_cloud, radius=1e-2)

            # PyTorch3D's point mask can overshoot by a subpixel, so erode it
            # by one pixel to avoid seams along the coverage boundary
            masks[0] = torch.from_numpy(
                skimage.morphology.binary_erosion(
                    masks[0].cpu().numpy(), footprint=skimage.morphology.disk(1)
                )
            ).to(device)

            # keep only the points that fall outside the covered region
            uncovered = ~masks[0].view(-1)
            partial_outpainted_point_cloud = get_pointcloud(
                xy_depth_world[0][uncovered], device=device, features=rgb[uncovered]
            )

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])
        
    return point_cloud
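

# --- usage sketch (hypothetical) -------------------------------------------
# A minimal, illustrative example of the expected `optimization_bundle`
# layout, inferred from how the fields are accessed above; the file names and
# parameter values are assumptions, not a confirmed API of this module.
#
#   from PIL import Image
#   import numpy as np
#
#   bundle = {
#       "frames": [
#           {
#               "image": Image.open("frame_000.png").convert("RGB"),  # PIL image
#               "depth": np.load("frame_000_depth.npy"),              # (H, W) depth map
#               "azim": 0.0, "elev": 10.0, "dist": 1.0,               # camera pose
#           },
#           # frames flagged with {"supporting": True, ...} are skipped entirely
#       ]
#   }
#   point_cloud = downsample_point_cloud(bundle, device="cuda")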