from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F

# Fixed seed so the random ray sub-sampling in render_rays() is reproducible.
rng = np.random.RandomState(234)
|
|
def volume_sampling(sample_pts, features, aabb):
    """Trilinearly sample a [1, C, D, W, H] feature volume at 3D points.

    Returns per-point features of shape [N_rays, N_samples, C] together
    with a boolean in-bound mask of shape [N_rays, N_samples].
    """
    B, C, D, W, H = features.shape
    assert B == 1
    aabb = torch.Tensor(aabb).to(sample_pts.device)
    N_rays, N_samples, coords = sample_pts.shape
    sample_pts = sample_pts.view(1, N_rays * N_samples, 1, 1,
                                 3).repeat(B, 1, 1, 1, 1)
    aabbSize = aabb[1] - aabb[0]
    invgridSize = 1.0 / aabbSize * 2
    # Normalize points into [-1, 1], the coordinate range of F.grid_sample.
    norm_pts = (sample_pts - aabb[0]) * invgridSize - 1
    sample_features = F.grid_sample(
        features, norm_pts, align_corners=True, padding_mode='border')
    # A point is in-bound only if all three normalized coords lie in (-1, 1).
    masks = ((norm_pts < 1) & (norm_pts > -1)).float().sum(dim=-1)
    masks = (masks.view(N_rays, N_samples) == 3)
    return sample_features.view(C, N_rays,
                                N_samples).permute(1, 2, 0).contiguous(), masks
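
# Minimal usage sketch for volume_sampling() (hypothetical shapes and scene
# bounds, not taken from any project config):
#
#   feats = torch.randn(1, 32, 40, 40, 16)                 # [1, C, D, W, H]
#   pts = torch.rand(1024, 64, 3) * 8 - 4                  # scene-space points
#   aabb = np.array([[-4.0, -4.0, -1.0], [4.0, 4.0, 3.0]])
#   feat_pts, inbound = volume_sampling(pts, feats, aabb)
#   # feat_pts: [1024, 64, 32]; inbound: [1024, 64] bool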
|
|
def _compute_projection(img_meta):
    """Pack per-view camera parameters into a [1, views, 34] tensor.

    Each row is [img_h, img_w, intrinsic(16), extrinsic(16)], with the
    intrinsic rescaled from the original image resolution to the network
    input resolution.
    """
    views = len(img_meta['lidar2img']['extrinsic'])
    intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:4, :4])
    ratio = img_meta['ori_shape'][0] / img_meta['img_shape'][0]
    intrinsic[:2] /= ratio
    intrinsic = intrinsic.unsqueeze(0).view(1, 16).repeat(views, 1)
    img_size = torch.Tensor(img_meta['img_shape'][:2]).to(intrinsic.device)
    img_size = img_size.unsqueeze(0).repeat(views, 1)
    extrinsics = []
    for v in range(views):
        extrinsics.append(
            torch.Tensor(img_meta['lidar2img']['extrinsic'][v]).to(
                intrinsic.device))
    extrinsic = torch.stack(extrinsics).view(views, 16)
    train_cameras = torch.cat([img_size, intrinsic, extrinsic], dim=-1)
    return train_cameras.unsqueeze(0)
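
# Sketch of the img_meta fields _compute_projection() reads (illustrative
# values only; float32 matrices keep the final torch.cat() from mixing
# dtypes):
#
#   img_meta = {
#       'lidar2img': {
#           'intrinsic': np.eye(4, dtype=np.float32),
#           'extrinsic': [np.eye(4, dtype=np.float32) for _ in range(2)],
#       },
#       'ori_shape': (968, 1296, 3),
#       'img_shape': (480, 640, 3),
#   }
#   cams = _compute_projection(img_meta)   # [1, 2, 2 + 16 + 16]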
|
|
def compute_mask_points(feature, mask):
    """Masked mean/variance of multi-view features over the view dim (dim=2).

    The variance is returned as exp(-var), so features that agree across
    views map to values near 1.
    """
    weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
    mean = torch.sum(feature * weight, dim=2, keepdim=True)
    # The squared deviations are summed over all views (masked ones
    # included); this matches the original implementation.
    var = torch.sum((feature - mean)**2, dim=2, keepdim=True)
    var = var / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
    var = torch.exp(-var)
    return mean, var
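
# Usage sketch for compute_mask_points() (hypothetical shapes):
#
#   feat = torch.randn(1024, 64, 3, 32)               # [rays, samples, V, C]
#   mask = torch.randint(0, 2, (1024, 64, 3, 1)).float()
#   mean, var = compute_mask_points(feat, mask)
#   # mean, var: [1024, 64, 1, 32]; var lies in (0, 1]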
|
|
def sample_pdf(bins, weights, N_samples, det=False):
    """Helper function used for sampling.

    Args:
        bins (tensor): Tensor of shape [N_rays, M+1], M is the number of bins
        weights (tensor): Tensor of shape [N_rays, M], one weight per bin
        N_samples (int): Number of samples along each ray
        det (bool): If True, will perform deterministic sampling

    Returns:
        samples (tensor): Tensor of shape [N_rays, N_samples]
    """

    M = weights.shape[1]
    # Avoid zeros without mutating the caller's tensor in place.
    weights = weights + 1e-5

    # Normalized PDF and its CDF, with a leading zero so cdf has M+1 entries.
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[:, 0:1]), cdf], dim=-1)

    # Evenly spaced samples in [0, 1] if det, else uniform random in [0, 1).
    if det:
        u = torch.linspace(0., 1., N_samples, device=bins.device)
        u = u.unsqueeze(0).repeat(bins.shape[0], 1)
    else:
        u = torch.rand(bins.shape[0], N_samples, device=bins.device)

    # Invert the CDF: count how many CDF entries each u has passed.
    above_inds = torch.zeros_like(u, dtype=torch.long)
    for i in range(M):
        above_inds += (u >= cdf[:, i:i + 1]).long()

    below_inds = torch.clamp(above_inds - 1, min=0)
    # [N_rays, N_samples, 2]
    inds_g = torch.stack((below_inds, above_inds), dim=2)

    cdf = cdf.unsqueeze(1).repeat(1, N_samples, 1)
    cdf_g = torch.gather(input=cdf, dim=-1, index=inds_g)

    bins = bins.unsqueeze(1).repeat(1, N_samples, 1)
    bins_g = torch.gather(input=bins, dim=-1, index=inds_g)

    # Linearly interpolate within the selected bin.
    denom = cdf_g[:, :, 1] - cdf_g[:, :, 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[:, :, 0]) / denom

    samples = bins_g[:, :, 0] + t * (bins_g[:, :, 1] - bins_g[:, :, 0])

    return samples
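
# Usage sketch for sample_pdf() (hypothetical values): redistribute 64 fine
# samples according to coarse weights along 1024 rays.
#
#   bins = torch.linspace(0.5, 5.5, 65).unsqueeze(0).repeat(1024, 1)
#   coarse_w = torch.rand(1024, 64)
#   z_fine = sample_pdf(bins, coarse_w, N_samples=64, det=True)  # [1024, 64]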
|
|
def sample_along_camera_ray(ray_o,
                            ray_d,
                            depth_range,
                            N_samples,
                            inv_uniform=False,
                            det=False):
    """Sampling along the camera ray.

    Args:
        ray_o (tensor): Origin of the ray in scene coordinate system;
            tensor of shape [N_rays, 3]
        ray_d (tensor): Homogeneous ray direction vectors in
            scene coordinate system; tensor of shape [N_rays, 3]
        depth_range (tuple): [near_depth, far_depth]
        N_samples (int): Number of samples along each ray
        inv_uniform (bool): If True, sample uniformly in inverse depth
        det (bool): If True, will perform deterministic sampling

    Returns:
        pts (tensor): Tensor of shape [N_rays, N_samples, 3]
        z_vals (tensor): Tensor of shape [N_rays, N_samples]
    """

    near_depth_value = depth_range[0]
    far_depth_value = depth_range[1]
    assert near_depth_value > 0 and far_depth_value > 0 \
        and far_depth_value > near_depth_value

    near_depth = near_depth_value * torch.ones_like(ray_d[..., 0])
    far_depth = far_depth_value * torch.ones_like(ray_d[..., 0])

    if inv_uniform:
        start = 1. / near_depth
        step = (1. / far_depth - start) / (N_samples - 1)
        inv_z_vals = torch.stack([start + i * step for i in range(N_samples)],
                                 dim=1)
        z_vals = 1. / inv_z_vals
    else:
        start = near_depth
        step = (far_depth - near_depth) / (N_samples - 1)
        z_vals = torch.stack([start + i * step for i in range(N_samples)],
                             dim=1)

    if not det:
        # Stratified sampling: jitter each depth within its local interval.
        mids = .5 * (z_vals[:, 1:] + z_vals[:, :-1])
        upper = torch.cat([mids, z_vals[:, -1:]], dim=-1)
        lower = torch.cat([z_vals[:, 0:1], mids], dim=-1)
        t_rand = torch.rand_like(z_vals)
        z_vals = lower + (upper - lower) * t_rand

    ray_d = ray_d.unsqueeze(1).repeat(1, N_samples, 1)
    ray_o = ray_o.unsqueeze(1).repeat(1, N_samples, 1)
    pts = z_vals.unsqueeze(2) * ray_d + ray_o
    return pts, z_vals
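
# Usage sketch for sample_along_camera_ray() (hypothetical rays):
#
#   ray_o = torch.zeros(1024, 3)
#   ray_d = F.normalize(torch.randn(1024, 3), dim=-1)
#   pts, z_vals = sample_along_camera_ray(
#       ray_o, ray_d, depth_range=(0.2, 8.0), N_samples=64, det=True)
#   # pts: [1024, 64, 3]; z_vals: [1024, 64]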
|
|
def raw2outputs(raw, z_vals, mask, white_bkgd=False):
    """Transform raw network output to rendered outputs.

    Args:
        raw (tensor): Raw network output (RGB plus density).
            Tensor of shape [N_rays, N_samples, 4]
        z_vals (tensor): Depth of point samples along rays.
            Tensor of shape [N_rays, N_samples]
        mask (tensor): Per-sample visibility mask of shape
            [N_rays, N_samples], or None
        white_bkgd (bool): If True, composite onto a white background

    Returns:
        ret (dict):
            -rgb (tensor): [N_rays, 3]
            -depth (tensor): [N_rays,]
            -weights (tensor): [N_rays, N_samples]
            -mask (tensor): [N_rays,] or None
            -alpha (tensor): [N_rays, N_samples]
            -z_vals (tensor): [N_rays, N_samples]
            -transparency (tensor): [N_rays, N_samples]
    """
    rgb = raw[:, :, :3]
    sigma = raw[:, :, 3]

    # Note: `dists` is accepted but unused; opacity is taken directly as
    # 1 - exp(-sigma) rather than 1 - exp(-sigma * dists).
    sigma2alpha = lambda sigma, dists: 1. - torch.exp(-sigma)

    dists = z_vals[:, 1:] - z_vals[:, :-1]
    dists = torch.cat((dists, dists[:, -1:]), dim=-1)

    alpha = sigma2alpha(sigma, dists)

    # Accumulated transmittance T_i = prod_{j<i} (1 - alpha_j).
    T = torch.cumprod(1. - alpha + 1e-10, dim=-1)[:, :-1]
    T = torch.cat((torch.ones_like(T[:, 0:1]), T), dim=-1)

    weights = alpha * T
    rgb_map = torch.sum(weights.unsqueeze(2) * rgb, dim=1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - torch.sum(weights, dim=-1, keepdim=True))

    if mask is not None:
        # A ray counts as valid if enough of its samples are visible.
        mask = mask.float().sum(dim=1) > 8

    depth_map = torch.sum(
        weights * z_vals, dim=-1) / (
            torch.sum(weights, dim=-1) + 1e-8)
    depth_map = torch.clamp(depth_map, z_vals.min(), z_vals.max())

    ret = OrderedDict([('rgb', rgb_map), ('depth', depth_map),
                       ('weights', weights), ('mask', mask), ('alpha', alpha),
                       ('z_vals', z_vals), ('transparency', T)])

    return ret
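
# Usage sketch for raw2outputs() (hypothetical ray batch):
#
#   raw = torch.rand(1024, 64, 4)                   # rgb + sigma per sample
#   z_vals = torch.linspace(0.2, 8.0, 64).repeat(1024, 1)
#   out = raw2outputs(raw, z_vals, mask=None)
#   # out['rgb']: [1024, 3]; out['depth']: [1024]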
|
|
def render_rays_func(
        ray_o,
        ray_d,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',  # 'volume' or 'image'
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        gt_rgb=None,
        gt_depth=None):

    ret = {
        'outputs_coarse': None,
        'outputs_fine': None,
        'gt_rgb': gt_rgb,
        'gt_depth': gt_depth
    }

    # Sample points along each camera ray.
    pts, z_vals = sample_along_camera_ray(
        ray_o=ray_o,
        ray_d=ray_d,
        depth_range=near_far_range,
        N_samples=N_samples,
        inv_uniform=inv_uniform,
        det=det)
    N_rays, N_samples = pts.shape[:2]

    if mode == 'image':
        # Condition the MLP on 2D features projected from the source views.
        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        rgb_feat, mask = projector.compute(
            pts, img, train_camera, features_2D, grid_sample=True)
        pixel_mask = mask[..., 0].sum(dim=2) > 1
        mean, var = compute_mask_points(rgb_feat, mask)
        globalfeat = torch.cat([mean, var], dim=-1).squeeze(2)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalfeat)
        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)
        ret['sigma'] = density_pts

    elif mode == 'volume':
        # Condition the MLP on features sampled from the 3D mean/cov volumes.
        mean_pts, inbound_masks = volume_sampling(pts, mean_volume, aabb)
        cov_pts, inbound_masks = volume_sampling(pts, cov_volume, aabb)

        img = img.permute(0, 2, 3, 1).unsqueeze(0)
        train_camera = _compute_projection(img_meta).to(img.device)
        _, view_mask = projector.compute(pts, img, train_camera, None)
        pixel_mask = view_mask[..., 0].sum(dim=2) > 1

        globalpts = torch.cat([mean_pts, cov_pts], dim=-1)
        rgb_pts, density_pts = nerf_mlp(pts, ray_d, globalpts)
        # Zero out densities for points outside the scene AABB.
        density_pts = density_pts * inbound_masks.unsqueeze(dim=-1)

        raw_coarse = torch.cat([rgb_pts, density_pts], dim=-1)

    outputs_coarse = raw2outputs(
        raw_coarse, z_vals, pixel_mask, white_bkgd=white_bkgd)
    ret['outputs_coarse'] = outputs_coarse

    return ret
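
# Sketch of the ray_batch dict render_rays() expects (shapes inferred from
# the indexing below; the actual batches are built elsewhere in the project):
#
#   ray_batch = {
#       'ray_o': ...,       # ray origins, reshaped to [-1, 3]
#       'ray_d': ...,       # ray directions, reshaped to [-1, 3]
#       'gt_rgb': ...,      # ground-truth colors, reshaped to [-1, 3]
#       'gt_depth': ...,    # ground-truth depth; may be empty
#       'nerf_sizes': ...,  # nested render resolutions; (H, W) is read
#                           # from nerf_sizes[0][0]
#   }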
|
|
def render_rays(
        ray_batch,
        mean_volume,
        cov_volume,
        features_2D,
        img,
        aabb,
        near_far_range,
        N_samples,
        N_rand=4096,
        nerf_mlp=None,
        img_meta=None,
        projector=None,
        mode='volume',
        nerf_sample_view=3,
        inv_uniform=False,
        N_importance=0,
        det=False,
        is_train=True,
        white_bkgd=False,
        render_testing=False):
    """Render rays with the NeRF branch.

    At training time a random subset of N_rand rays is rendered; at test
    time all rays are rendered in chunks of N_rand and stitched back into
    full images.
    """

    ray_o = ray_batch['ray_o']
    ray_d = ray_batch['ray_d']
    gt_rgb = ray_batch['gt_rgb']
    gt_depth = ray_batch['gt_depth']
    nerf_sizes = ray_batch['nerf_sizes']
    if is_train:
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        if gt_depth.shape[1] != 0:
            gt_depth = gt_depth.view(-1, 1)
            # Keep only rays with a valid (positive) ground-truth depth.
            non_zero_depth = (gt_depth > 0).squeeze(-1)
            ray_o = ray_o[non_zero_depth]
            ray_d = ray_d[non_zero_depth]
            gt_rgb = gt_rgb[non_zero_depth]
            gt_depth = gt_depth[non_zero_depth]
        else:
            gt_depth = None
        # Randomly sub-sample N_rand rays for this training step.
        total_rays = ray_d.shape[0]
        select_inds = rng.choice(total_rays, size=(N_rand, ), replace=False)
        ray_o = ray_o[select_inds]
        ray_d = ray_d[select_inds]
        gt_rgb = gt_rgb[select_inds]
        if gt_depth is not None:
            gt_depth = gt_depth[select_inds]

        rets = render_rays_func(
            ray_o,
            ray_d,
            mean_volume,
            cov_volume,
            features_2D,
            img,
            aabb,
            near_far_range,
            N_samples,
            N_rand,
            nerf_mlp,
            img_meta,
            projector,
            mode,
            nerf_sample_view,
            inv_uniform,
            N_importance,
            det,
            is_train,
            white_bkgd,
            gt_rgb,
            gt_depth)

    elif render_testing:
        nerf_size = nerf_sizes[0]
        view_num = ray_o.shape[1]
        H = nerf_size[0][0]
        W = nerf_size[0][1]
        ray_o = ray_o.view(-1, 3)
        ray_d = ray_d.view(-1, 3)
        gt_rgb = gt_rgb.view(-1, 3)
        if len(gt_depth) != 0:
            gt_depth = gt_depth.view(-1, 1)
        else:
            gt_depth = None
        assert view_num * H * W == ray_o.shape[0]
        num_rays = ray_o.shape[0]
        results = []
        # Render in chunks of N_rand rays to bound peak memory.
        for i in range(0, num_rays, N_rand):
            ray_o_chunk = ray_o[i:i + N_rand, :]
            ray_d_chunk = ray_d[i:i + N_rand, :]

            ret = render_rays_func(ray_o_chunk, ray_d_chunk, mean_volume,
                                   cov_volume, features_2D, img, aabb,
                                   near_far_range, N_samples, N_rand, nerf_mlp,
                                   img_meta, projector, mode, nerf_sample_view,
                                   inv_uniform, N_importance, True, is_train,
                                   white_bkgd, gt_rgb, gt_depth)
            results.append(ret)

        rgbs = []
        depths = []

        if results[0]['outputs_coarse'] is not None:
            for i in range(len(results)):
                rgb = results[i]['outputs_coarse']['rgb']
                rgbs.append(rgb)
                depth = results[i]['outputs_coarse']['depth']
                depths.append(depth)

            # Stitch the chunks back into full [view_num, H, W, *] images.
            rets = {
                'outputs_coarse': {
                    'rgb': torch.cat(rgbs, dim=0).view(view_num, H, W, 3),
                    'depth': torch.cat(depths, dim=0).view(view_num, H, W, 1),
                },
                'gt_rgb':
                gt_rgb.view(view_num, H, W, 3),
                'gt_depth':
                gt_depth.view(view_num, H, W, 1)
                if gt_depth is not None else None,
            }
        else:
            rets = None
    return rets