|
|
|
"""Contains the function to march rays (integration).""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
__all__ = ['Integrator'] |
|
|
|
|
|
class Integrator(torch.nn.Module):
    """Defines the class to help march rays, i.e. do integral along each ray.

    The ray marcher takes the raw output of the implicit representation
    (including colors (i.e. rgbs) and densities (i.e. sigmas)) and uses the
    volume rendering equation to produce composited colors and depths.
    """

    def __init__(self):
        super().__init__()

    def integration(self, rgbs, sigmas, depths, rendering_options):
        """Integrate the values along the ray.

        `N` denotes batch size.
        `R` denotes the number of rays, equals `H * W`.
        `K` denotes the number of points on each ray.

        Args:
            rgbs (torch.Tensor): colors' value of each point in the fields,
                with shape [N, R, K, 3].
            sigmas (torch.Tensor): densities' value of each point in the
                fields, with shape [N, R, K, 1].
            depths (torch.Tensor): depths' value of each point in the fields,
                with shape [N, R, K, 1].
            rendering_options (dict): Rendering options. The keys read here
                (with their defaults) are:

                - `use_max_depth` (False): append one extra interval after the
                  last sample, so that `K` intervals are used instead of
                  `K - 1`.
                - `max_depth` (None): far bound used for that extra interval;
                  when `None`, the extra interval has length `1e10`.
                - `no_dist` (False): replace every interval length by 1.
                - `use_mid_point` (True): average adjacent samples, which
                  leaves `K - 1` samples per ray.
                - `clamp_mode` ('mipnerf'): one of `softplus`, `relu`,
                  `mipnerf` (i.e. `softplus(sigma - 1)`).
                - `last_back` (False): add the residual `1 - sum(weights)` to
                  the weight of the last sample.
                - `normalize_rgb` (False), `normalize_depth` (True): divide
                  the composite by the sum of weights.
                - `clip_depth` (True): replace NaN depth and clip it to the
                  range of the (possibly mid-point) sample depths.
                - `white_back` (False): composite onto a white background.

        Note:
            The number of intervals must match the number of samples. The
            default (`use_max_depth=False`, `use_mid_point=True`) gives
            `K - 1` of each; `use_max_depth=True` gives `K` intervals and
            therefore expects `use_mid_point=False`.

        Returns:
            A dictionary, containing
                - `composite_rgb`: composited color of each ray after the
                    final `x * 2 - 1` rescaling, with shape [N, R, 3].
                - `composite_depth`: composited depth of each ray, with shape
                    [N, R, 1].
                - `weights`: importance weights of each point in the field,
                    with shape [N, R, K - 1, 1] when the mid-point rule is
                    used, and [N, R, K, 1] otherwise.

        Raises:
            ValueError: If `clamp_mode` is not a supported mode.
        """
        num_dims = rgbs.ndim
        assert num_dims == 4
        assert sigmas.ndim == num_dims and depths.ndim == num_dims

        # Distance between adjacent samples along each ray: [N, R, K - 1, 1].
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        if rendering_options.get('use_max_depth', False):
            max_depth = rendering_options.get('max_depth', None)
            if max_depth is not None:
                # The last interval spans from the last sample to the far
                # bound, so it is measured against the sample depth (not
                # against the previous interval length).
                delta_inf = max_depth - depths[:, :, -1:]
            else:
                # Treat the last interval as (practically) infinite.
                delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1])
            deltas = torch.cat([deltas, delta_inf], -2)
        if rendering_options.get('no_dist', False):
            deltas = torch.ones_like(deltas)

        use_mid_point = rendering_options.get('use_mid_point', True)
        if use_mid_point:
            # Evaluate every quantity at the middle of each interval.
            rgbs = (rgbs[:, :, :-1] + rgbs[:, :, 1:]) / 2
            sigmas = (sigmas[:, :, :-1] + sigmas[:, :, 1:]) / 2
            depths = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        # Map raw densities to non-negative values.
        clamp_mode = rendering_options.get('clamp_mode', 'mipnerf')
        if clamp_mode == 'softplus':
            sigmas = F.softplus(sigmas)
        elif clamp_mode == 'relu':
            sigmas = F.relu(sigmas)
        elif clamp_mode == 'mipnerf':
            sigmas = F.softplus(sigmas - 1)
        else:
            raise ValueError(f'Invalid clamping mode: `{clamp_mode}`!\n')

        # Opacity of each interval, then transmittance-weighted contribution.
        alphas = 1 - torch.exp(- deltas * sigmas)
        # Prepending ones makes the cumulative product exclusive, i.e. the
        # transmittance in front of each sample. The epsilon keeps the product
        # away from exact zero.
        alphas_shifted = torch.cat(
            [torch.ones_like(alphas[:, :, :1]), 1 - alphas + 1e-10], -2)
        weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1]
        # NOTE: `weights_sum` is taken before the optional `last_back`
        # adjustment below and is not refreshed afterwards.
        weights_sum = weights.sum(2)
        if rendering_options.get('last_back', False):
            weights[:, :, -1] = weights[:, :, -1] + (1 - weights_sum)

        composite_rgb = torch.sum(weights * rgbs, -2)
        composite_depth = torch.sum(weights * depths, -2)

        if rendering_options.get('normalize_rgb', False):
            composite_rgb = composite_rgb / weights_sum
        if rendering_options.get('normalize_depth', True):
            composite_depth = composite_depth / weights_sum
        if rendering_options.get('clip_depth', True):
            # NaN (e.g. from 0 / 0 on empty rays) becomes +inf and is then
            # clipped down to the largest sample depth.
            composite_depth = torch.nan_to_num(composite_depth,
                                               nan=float('inf'))
            composite_depth = torch.clip(composite_depth, torch.min(depths),
                                         torch.max(depths))

        if rendering_options.get('white_back', False):
            composite_rgb = composite_rgb + 1 - weights_sum

        composite_rgb = composite_rgb * 2 - 1

        results = {
            'composite_rgb': composite_rgb,
            'composite_depth': composite_depth,
            'weights': weights
        }

        return results

    def forward(self, rgbs, sigmas, depths, rendering_options):
        """Run `integration()`; see that method for arguments and returns."""
        results = self.integration(rgbs, sigmas, depths, rendering_options)
        return results
|
|