import torch
import torch.nn.functional as F
import skimage.morphology
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    PerspectiveCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)

from .ops import nearest_neighbor_fill
from typing import cast, Optional, Tuple


class PointsRendererWithMasks(PointsRenderer):
    def forward(self, point_clouds, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        fragments = self.rasterizer(point_clouds, **kwargs)

        # Construct weights based on the distance of a point to the true point.
        # However, this could be done differently: e.g. predicted as opposed
        # to a function of the weights.
        r = self.rasterizer.raster_settings.radius
        dists2 = fragments.dists
        weights = torch.ones_like(dists2)  # alternatively: 1 - dists2 / (r * r)

        # Zero out the weights of pixel slots that no point was rasterized to
        # (their index in fragments.idx is -1).
        ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()
        weights = weights * ok

        fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
        weights_prm = weights.permute(0, 3, 1, 2)
        images = self.compositor(
            fragments_prm,
            weights_prm,
            point_clouds.features_packed().permute(1, 0),
            **kwargs,
        )

        # Composite the z-buffer front-to-back along the points-per-pixel
        # dimension; with binary weights this picks the depth of the nearest
        # rasterized point per pixel.
        cumprod = torch.cumprod(1 - weights, dim=-1)
        cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
        depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)

        # permute so image comes at the end
        images = images.permute(0, 2, 3, 1)

        # A pixel belongs to the mask if at least one point was rasterized to it.
        masks = fragments.idx.long()[..., 0] >= 0

        return images, masks, depths


def render_with_settings(cameras, point_cloud, raster_settings, antialiasing: int = 1):
    if antialiasing > 1:
        # Supersample: rasterize at a higher resolution and pool down afterwards.
        raster_settings.image_size = (raster_settings.image_size[0] * antialiasing, raster_settings.image_size[1] * antialiasing)

    rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
    renderer = PointsRendererWithMasks(
        rasterizer=rasterizer,
        compositor=AlphaCompositor()
    )

    if antialiasing > 1:
        images, masks, depths = renderer(point_cloud)
        images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
        images = F.avg_pool2d(images, kernel_size=antialiasing, stride=antialiasing)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC
        # Bring masks and depths to the target resolution as well, so all
        # outputs share the same spatial size (a mask pixel is kept if any of
        # its supersampled pixels was covered).
        masks = F.max_pool2d(masks.float().unsqueeze(1), kernel_size=antialiasing, stride=antialiasing).squeeze(1).bool()
        depths = F.avg_pool2d(depths.unsqueeze(1), kernel_size=antialiasing, stride=antialiasing).squeeze(1)
        return images, masks, depths
    else:
        return renderer(point_cloud)


def render(cameras, point_cloud, fill_point_cloud_holes: bool = False, radius: Optional[float] = None, antialiasing: int = 1):
    if fill_point_cloud_holes:
        # First render a coarse silhouette with a large point radius to
        # estimate which pixels the point cloud should cover.
        coarse_raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=1e-2,
            points_per_pixel=1
        )
        _, coarse_mask, _ = render_with_settings(cameras, point_cloud, coarse_raster_settings)
        eroded_coarse_mask = torch.from_numpy(
            skimage.morphology.binary_erosion(coarse_mask[0].cpu().numpy(), footprint=skimage.morphology.disk(2))
        )

        raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=(1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
            points_per_pixel=16
        )

        # Render the scene
        images, masks, depths = render_with_settings(cameras, point_cloud, raster_settings, antialiasing=antialiasing)

        # Pixels covered by the eroded coarse mask but missed by the fine
        # rendering are holes; fill them from their nearest rendered neighbors.
        holes_in_rendering = masks[0].cpu() ^ eroded_coarse_mask
        images[0] = nearest_neighbor_fill(images[0], ~holes_in_rendering, 0)
        depths[0] = nearest_neighbor_fill(depths[0], ~holes_in_rendering, 0)

        return images, eroded_coarse_mask.unsqueeze(0).to(masks.device), depths
    else:
        raster_settings = PointsRasterizationSettings(
            image_size=(int(cameras.image_size[0, 1]), int(cameras.image_size[0, 0])),
            radius=(1 / float(max(cameras.image_size[0, 1], cameras.image_size[0, 0])) * 2.0) if radius is None else radius,
            points_per_pixel=16
        )

        # Render the scene
        return render_with_settings(cameras, point_cloud, raster_settings, antialiasing=antialiasing)
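

if __name__ == "__main__":
    # Minimal usage sketch (not part of the module above): builds a random
    # colored point cloud and a screen-space PerspectiveCameras, then renders
    # it with 2x supersampling. All camera parameters below are illustrative
    # assumptions, not values prescribed by this repository.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    points = torch.rand(1, 10000, 3, device=device) - 0.5   # points in a unit cube around the origin
    colors = torch.rand(1, 10000, 3, device=device)          # per-point RGB features
    example_cloud = Pointclouds(points=points, features=colors)

    # Camera looking at the origin from 2 units away; focal length and
    # principal point are given in pixels (in_ndc=False), image is 256x256.
    R, T = look_at_view_transform(dist=2.0, elev=10.0, azim=0.0)
    example_cameras = PerspectiveCameras(
        R=R,
        T=T,
        focal_length=torch.tensor([[256.0, 256.0]]),
        principal_point=torch.tensor([[128.0, 128.0]]),
        image_size=torch.tensor([[256, 256]]),
        in_ndc=False,
        device=device,
    )

    images, masks, depths = render(example_cameras, example_cloud, antialiasing=2)
    print(images.shape, masks.shape, depths.shape)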