Spaces:
Runtime error
Runtime error
File size: 6,414 Bytes
21c4e64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_normalized_directions(directions):
    """Shift direction components from [-1, 1] into [0, 1].

    The spherical-harmonics encoding consumes inputs in the range [0, 1], so
    each component ``d`` is remapped to ``(d + 1) / 2``.

    Args:
        directions: batch of directions (any value supporting ``+`` and ``/``).

    Returns:
        The remapped directions, same type/shape as the input.
    """
    shifted = directions + 1.0
    return shifted / 2.0
def normalize_aabb(pts, aabb):
    """Affinely map points out of the box ``aabb`` into [-1, 1].

    On every axis ``aabb[0]`` is sent to -1 and ``aabb[1]`` to +1; the rows
    are not required to be ordered (min, max).

    Args:
        pts: points whose last axis matches the width of ``aabb``'s rows.
        aabb: indexable pair of corners ``aabb[0]``, ``aabb[1]``.

    Returns:
        ``(pts - aabb[0]) * 2 / (aabb[1] - aabb[0]) - 1``.
    """
    corner0 = aabb[0]
    corner1 = aabb[1]
    scale = 2.0 / (corner1 - corner0)
    return (pts - corner0) * scale - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Sample a 2-D or 3-D feature grid at a set of normalized coordinates.

    Thin wrapper around ``F.grid_sample`` (``mode='bilinear'``,
    ``padding_mode='border'``) that adds missing batch axes and flattens the
    sampling layout to one row per query point.

    Args:
        grid: feature grid ``[B, feature_dim, *reso]``, or ``[feature_dim, *reso]``
            without a batch axis, where ``len(reso) == coords.shape[-1]``.
        coords: query positions ``[B, n, grid_dim]`` or ``[n, grid_dim]`` in
            grid_sample's normalized [-1, 1] space, last axis ordered (x, y[, z]).
        align_corners: forwarded unchanged to ``F.grid_sample``.

    Returns:
        ``[B, n, feature_dim]`` features with every size-1 axis squeezed out
        (typically ``[n, feature_dim]`` when ``B == 1``).

    Raises:
        NotImplementedError: if ``coords.shape[-1]`` is neither 2 nor 3.
    """
    ndim = coords.shape[-1]
    if ndim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {ndim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # Supply a leading batch axis wherever the caller left it out.
    batched_grid = grid.unsqueeze(0) if grid.dim() == ndim + 1 else grid
    batched_coords = coords.unsqueeze(0) if coords.dim() == 2 else coords

    # grid_sample expects [B, <ndim output axes>, ndim]; lay the n queries along the
    # last output axis and make the remaining (ndim - 1) output axes singletons.
    sample_coords = batched_coords.view(
        (batched_coords.shape[0], *([1] * (ndim - 1)), *batched_coords.shape[1:])
    )
    batch, channels = batched_grid.shape[:2]
    num_pts = sample_coords.shape[-2]

    sampled = F.grid_sample(
        batched_grid,   # [B, feature_dim, reso, ...]
        sample_coords,  # [B, 1, ..., n, grid_dim]
        align_corners=align_corners,
        mode='bilinear', padding_mode='border')
    # [B, feature_dim, 1, ..., n] -> [B, n, feature_dim], then drop unit axes.
    return sampled.view(batch, channels, num_pts).transpose(-1, -2).squeeze()
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature plane per combination of input axes.

    For every ``grid_nd``-sized combination of the ``in_dim`` input axes, a
    parameter of shape ``[1, out_dim, reso[k_last], ..., reso[k_first]]`` is
    allocated.  The axis order is reversed so that the first coordinate of a
    combination indexes the plane's last (width) axis, matching grid_sample's
    (x, y[, z]) convention.

    When ``in_dim == 4``, every plane involving axis 3 (time planes) is filled
    with ones.  All other planes are drawn from ``U(a, b)``.

    Args:
        grid_nd: dimensionality of each plane (2 for planes).
        in_dim: number of input coordinates; must equal ``len(reso)``.
        out_dim: feature channels per plane.
        reso: resolution for each input axis.
        a: lower bound of the uniform initialization.
        b: upper bound of the uniform initialization.

    Returns:
        ``nn.ParameterList`` ordered like ``itertools.combinations(range(in_dim), grid_nd)``.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    time_axis_present = in_dim == 4
    assert grid_nd <= in_dim

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        plane_shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        plane = nn.Parameter(torch.empty(plane_shape))
        if time_axis_present and 3 in axes:
            nn.init.ones_(plane)  # time planes start as identity multipliers
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)
    return planes
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Sample multi-scale plane grids and combine their features.

    Within one scale, the features sampled from every plane are multiplied
    together.  Across scales, they are either concatenated along the last axis
    (``concat_features=True``) or summed.

    Args:
        pts: query coordinates; the last axis holds the input coordinates whose
            ``grid_dimensions``-sized combinations select the planes.
        ms_grids: per-scale sequences of planes, ordered like
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
        grid_dimensions: number of coordinates per plane.
        concat_features: concatenate (True) or sum (False) across scales.
        num_levels: number of leading scales to use; ``None`` means all.

    Returns:
        The concatenated tensor, or the running sum (the float ``0.`` if no
        scale is used).
    """
    axis_combos = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    levels_to_use = len(ms_grids) if num_levels is None else num_levels

    level_features = []
    summed = 0.
    planes: nn.ParameterList
    for planes in ms_grids[:levels_to_use]:
        product = 1.
        for plane_idx, axes in enumerate(axis_combos):
            plane = planes[plane_idx]
            channels = plane.shape[1]  # plane is [1, out_dim, *reso]
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)
            product = product * sampled
        if concat_features:
            level_features.append(product)
        else:
            summed = summed + product

    if not concat_features:
        return summed
    return torch.cat(level_features, dim=-1)
class HexPlaneField(nn.Module):
    """Multi-scale, plane-factorized feature field (HexPlane / K-Planes style).

    For every multiplier in ``multires``, one set of planes is built with
    ``init_grid_param``; only the first three resolution entries are scaled by
    that multiplier.  ``forward`` normalizes query points into the ``aabb``
    box, appends ``timestamps`` when given, samples every plane and returns the
    per-scale features concatenated along the last axis.

    Args:
        bounds: half-extent of the initial cubic box; ``aabb`` becomes
            ``[[bounds] * 3, [-bounds] * 3]``.
        planeconfig: dict with keys ``"grid_dimensions"``,
            ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
            ``"resolution"``.
        multires: iterable of resolution multipliers, one per scale.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        # Row 0 is the (+bounds) corner and row 1 the (-bounds) corner; set_aabb
        # keeps the same (max, min) row order. float32 is pinned so that, e.g.,
        # an int ``bounds`` does not produce an int64 box.
        aabb = torch.tensor([[bounds, bounds, bounds],
                             [-bounds, -bounds, -bounds]], dtype=torch.float32)
        # Fixed normalization box: registered as a parameter (so it follows
        # .to()/state_dict) but not trained.
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes, one ParameterList per scale.
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # A shallow copy suffices: "resolution" is rebound below, never
            # mutated in place, so the caller's planeconfig stays untouched.
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes (first three axes).
            # list() also accepts a tuple-valued "resolution".
            config["resolution"] = (
                [r * res for r in config["resolution"][:3]]
                + list(config["resolution"][3:])
            )
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out_dim: concatenated over scales, or one scale's width.
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        print("feature_dim:", self.feat_dim)

    def set_aabb(self, xyz_max, xyz_min, *, requires_grad: bool = False):
        """Replace the normalization box with the rows ``[xyz_max, xyz_min]``.

        Args:
            xyz_max: first row of the new box (list, numpy array or tensor of 3 values).
            xyz_min: second row of the new box.
            requires_grad: whether the new box is trainable. Defaults to
                ``False`` to match the box created in ``__init__``.
        """
        # Build on the current box's device as float32. torch.tensor() on a numpy
        # float64 array would give a float64 box: that promotes the normalized
        # coordinates to float64, which mismatches the float32 grids in
        # F.grid_sample. It would also land on the CPU after the module was moved,
        # and it cannot take a list of multi-element tensors.
        device = self.aabb.device
        aabb = torch.stack([
            torch.as_tensor(xyz_max, dtype=torch.float32, device=device),
            torch.as_tensor(xyz_min, dtype=torch.float32, device=device),
        ]).detach()
        self.aabb = nn.Parameter(aabb, requires_grad=requires_grad)
        print("Voxel Plane: set aabb=", self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Compute the multi-scale plane features for a batch of points.

        Despite the name, this returns the raw concatenated plane features
        (``self.feat_dim`` wide), not densities.

        Args:
            pts: spatial coordinates, normalized against ``self.aabb``; the
                last axis must broadcast against the 3-wide aabb rows.
            timestamps: extra coordinate(s) appended after normalization (not
                normalized). May be omitted only when the plane config's
                ``input_coordinate_dim`` equals ``pts.shape[-1]``.

        Returns:
            ``[N, feat_dim]`` features, where ``N`` is the number of query
            points after flattening all leading axes; ``(0, 1)`` zeros when
            there are no points.

        Raises:
            ValueError: if ``pts`` plus ``timestamps`` do not provide exactly
                ``input_coordinate_dim`` coordinates.
        """
        expected_dim = self.grid_config[0]["input_coordinate_dim"]
        extra_dim = 0 if timestamps is None else timestamps.shape[-1]
        # Validate up front: a wrong width would otherwise crash in torch.cat
        # (timestamps=None) or silently pair coordinates with the wrong planes.
        if pts.shape[-1] + extra_dim != expected_dim:
            raise ValueError(
                f"HexPlaneField expects {expected_dim} input coordinates but got "
                f"{pts.shape[-1]} from pts and {extra_dim} from timestamps."
            )
        pts = normalize_aabb(pts, self.aabb)
        if timestamps is not None:
            pts = torch.cat((pts, timestamps), dim=-1)  # e.g. [n_rays, n_samples, 4]
        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        # No query points: return the (0, 1) placeholder the original code used.
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for :meth:`get_density`; returns the plane features."""
        features = self.get_density(pts, timestamps)
        return features
|