# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Zexin He
# The modifications are subject to the same license as the original.
""" | |
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. | |
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |


class LearnedVariance(nn.Module):
    def __init__(self, init_val):
        super(LearnedVariance, self).__init__()
        self.register_parameter("_inv_std", nn.Parameter(torch.tensor(init_val)))

    @property
    def inv_std(self):
        # The raw parameter is stored in log-space, scaled by 10 for faster optimization.
        val = torch.exp(self._inv_std * 10.0)
        return val

    def forward(self, x):
        # Broadcast the (clamped) inverse standard deviation to the shape of the query tensor.
        return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6)


class MipRayMarcher2(nn.Module):
    def __init__(self, activation_factory):
        super().__init__()
        self.activation_factory = activation_factory
        self.variance = LearnedVariance(0.3)
        self.cos_anneal_ratio = 1.0
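
    # Note on the method below: get_alpha converts mid-point SDF values into per-section
    # opacities (NeuS-style unbiased alpha). With Phi_s the sigmoid sharpened by the learned
    # inv_std, and prev/next the SDF estimated at the two ends of a section, it computes
    #     alpha = clip((Phi_s(prev_sdf) - Phi_s(next_sdf)) / Phi_s(prev_sdf), 0, 1),
    # i.e. the fraction of "occupancy" the ray loses while crossing that section.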
    def get_alpha(self, sdf, normal, dirs, dists):
        # sdf: [N, 1], normal: [N, 3], dirs: [N, 3], dists: [N, 1]
        inv_std = self.variance(sdf)
        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # "cos_anneal_ratio" grows from 0 to 1 during the early training iterations. The annealing
        # strategy below keeps the cos value "not dead" at the beginning of training, for better convergence.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )  # always non-positive

        # Estimate signed distances at the section endpoints.
        estimated_next_sdf = sdf + iter_cos * dists * 0.5
        estimated_prev_sdf = sdf - iter_cos * dists * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_std)

        p = prev_cdf - next_cdf
        c = prev_cdf

        alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)
        return alpha

    def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        # depths: [B, N_ray, N_sample, 1]
        # sdfs:   [B, N_ray, N_sample, 1]
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        sdfs_mid = (sdfs[:, :, :-1] + sdfs[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
        normals_mid = (normals[:, :, :-1] + normals[:, :, 1:]) / 2
        # Added for normal compositing (zhaohx):
        real_normals_mid = (real_normals[:, :, :-1] + real_normals[:, :, 1:]) / 2

        # Original density-based path (replaced by the SDF-based alpha below):
        # densities_mid = self.activation_factory(rendering_options)(densities_mid)
        # density_delta = densities_mid * deltas
        # alpha = 1 - torch.exp(-density_delta)

        dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1)
        B, N_ray, N_sample, _ = sdfs_mid.shape
        alpha = self.get_alpha(sdfs_mid.reshape(-1, 1), normals_mid.reshape(-1, 3), dirs.reshape(-1, 3), deltas.reshape(-1, 1))
        alpha = alpha.reshape(B, N_ray, N_sample, -1)

        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

        # Clip the composite to the min/max range of depths.
        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        # Added for normal compositing (zhaohx):
        composite_normal = torch.sum(weights * real_normals_mid, -2) / weight_total
        composite_normal = torch.nan_to_num(composite_normal, float('inf'))
        composite_normal = torch.clamp(composite_normal, torch.min(real_normals), torch.max(real_normals))

        if rendering_options.get('white_back', False):
            # Earlier variants (kept for reference) blended as
            # `composite_rgb * weight_total + (1 - weight_total) * background` instead.
            if bgcolor is None:
                # Composite over a white background.
                composite_rgb = composite_rgb + 1 - weight_total
            else:
                # Composite over the provided background image (bgcolor: [B, C, H, W] with H*W == N_ray).
                bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1])
                composite_rgb = composite_rgb + (1 - weight_total) * bgcolor

        # Rendered values stay in the 0-1 range; the original MipNeRF scaling is disabled.
        # composite_rgb = composite_rgb * 2 - 1  # Scale to (-1, 1)

        return composite_rgb, composite_depth, weights, composite_normal

    def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        composite_rgb, composite_depth, weights, composite_normal = self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals)
        return composite_rgb, composite_depth, weights, composite_normal
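

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): runs the marcher on random tensors to
# show the shapes it expects. The batch/ray/sample sizes, the random inputs, and
# passing `activation_factory=None` (it is unused on the SDF path above) are
# assumptions for this smoke test, not values from the original pipeline.
if __name__ == "__main__":
    B, N_ray, N_sample = 2, 128, 48
    colors = torch.rand(B, N_ray, N_sample, 3)                           # per-sample RGB
    sdfs = torch.randn(B, N_ray, N_sample, 1) * 0.1                      # per-sample signed distance
    depths, _ = torch.sort(torch.rand(B, N_ray, N_sample, 1), dim=-2)    # increasing along each ray
    normals = F.normalize(torch.randn(B, N_ray, N_sample, 3), dim=-1)    # normals for the alpha computation
    real_normals = normals.clone()                                       # normals used for compositing
    ray_directions = F.normalize(torch.randn(B, N_ray, 3), dim=-1)

    marcher = MipRayMarcher2(activation_factory=None)
    rgb, depth, weights, normal = marcher(
        colors, sdfs, depths, normals, ray_directions,
        rendering_options={'white_back': False},
        bgcolor=None, real_normals=real_normals,
    )
    print(rgb.shape, depth.shape, weights.shape, normal.shape)
    # Expected: [B, N_ray, 3], [B, N_ray, 1], [B, N_ray, N_sample - 1, 1], [B, N_ray, 3]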