import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_normalized_directions(directions):
    """Map direction components from [-1, 1] to [0, 1].

    SH encoding must be in the range [0, 1].

    Args:
        directions: batch of directions.

    Returns:
        ``(directions + 1) / 2``.
    """
    return (directions + 1.0) / 2.0


def normalize_aabb(pts, aabb):
    """Affinely rescale ``pts`` so that ``aabb[0]`` maps to -1 and ``aabb[1]`` maps to +1.

    Args:
        pts: points; the last dimension must broadcast against ``aabb[0]``.
        aabb: tensor holding two corners, ``aabb[0]`` and ``aabb[1]``.
            ``HexPlaneField`` stores the maximum corner first, so points inside
            its box land in [-1, 1] with every axis mirrored.

    Returns:
        ``(pts - aabb[0]) * 2 / (aabb[1] - aabb[0]) - 1``.
    """
    return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0


def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Sample a 2D or 3D feature grid at a flat list of coordinates.

    Args:
        grid: feature grid ``[B, feature_dim, *reso]`` or, without a batch
            dimension, ``[feature_dim, *reso]``.
        coords: sample locations ``[B, n, grid_dim]`` or ``[n, grid_dim]``; the
            last dimension (2 or 3) selects 2D or 3D sampling. Coordinates follow
            the normalized convention of ``F.grid_sample``.
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        Interpolated features ``[B, n, feature_dim]`` with every size-1 dimension
        squeezed out (so ``[n, feature_dim]`` when ``B == 1``).

    Raises:
        NotImplementedError: if the coordinate dimension is not 2 or 3.
    """
    grid_dim = coords.shape[-1]

    if grid.dim() == grid_dim + 1:
        # no batch dimension present, need to add it
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)

    if grid_dim == 2 or grid_dim == 3:
        grid_sampler = F.grid_sample
    else:
        raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # Insert singleton spatial dims: coords become [B, 1, ..., n, grid_dim], the
    # layout F.grid_sample expects for a 4-D (2D) or 5-D (3D) input grid.
    coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:]))
    B, feature_dim = grid.shape[:2]
    n = coords.shape[-2]
    interp = grid_sampler(
        grid,  # [B, feature_dim, reso, ...]
        coords,  # [B, 1, ..., n, grid_dim]
        align_corners=align_corners,
        mode='bilinear', padding_mode='border')
    interp = interp.view(B, feature_dim, n).transpose(-1, -2)  # [B, n, feature_dim]
    # NOTE: squeeze() drops *every* size-1 dim (the batch, but also n == 1 or
    # feature_dim == 1); interpolate_ms_features re-views the result explicitly.
    interp = interp.squeeze()  # [B?, n, feature_dim?]
    return interp


def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature plane per ``grid_nd``-subset of the input axes.

    For every combination ``coo_comb`` of ``grid_nd`` axes out of ``range(in_dim)``
    (in ``itertools.combinations`` order) a parameter of shape
    ``[1, out_dim] + [reso[c] for c in reversed(coo_comb)]`` is created. The
    resolution order is reversed because ``F.grid_sample`` indexes the last grid
    dimension with the first coordinate component.

    When ``in_dim == 4``, planes that involve axis 3 are initialized to ones;
    all other planes are drawn uniformly from ``[a, b]``.

    Args:
        grid_nd: number of axes spanned by each plane.
        in_dim: number of input coordinates.
        out_dim: feature channels per plane.
        reso: resolution per input axis; must have ``in_dim`` entries.
        a: lower bound of the uniform initialization.
        b: upper bound of the uniform initialization.

    Returns:
        ``nn.ParameterList`` with one parameter per axis combination.

    Raises:
        ValueError: if ``len(reso) != in_dim`` or ``grid_nd > in_dim``.
    """
    if in_dim != len(reso):
        raise ValueError("Resolution must have same number of elements as input-dimension")
    if grid_nd > in_dim:
        raise ValueError(f"grid_nd ({grid_nd}) must not exceed in_dim ({in_dim})")
    has_time_planes = in_dim == 4
    grid_coefs = nn.ParameterList()
    for coo_comb in itertools.combinations(range(in_dim), grid_nd):
        new_grid_coef = nn.Parameter(torch.empty(
            [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
        ))
        if has_time_planes and 3 in coo_comb:
            # Initialize time planes to 1 so that, at initialization, they act as
            # identity factors in the per-plane product of interpolate_ms_features.
            nn.init.ones_(new_grid_coef)
        else:
            nn.init.uniform_(new_grid_coef, a=a, b=b)
        grid_coefs.append(new_grid_coef)
    return grid_coefs


def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Interpolate multi-scale plane features at ``pts``.

    Within each scale, plane ``ci`` is sampled at the coordinates selected by the
    ``ci``-th entry of ``itertools.combinations(range(pts.shape[-1]),
    grid_dimensions)``, and the per-plane features are multiplied elementwise.
    The per-scale results are then concatenated along the feature axis
    (``concat_features=True``) or summed.

    Args:
        pts: points whose last dimension holds the input coordinates.
        ms_grids: per-scale plane collections, each ordered like the
            combinations above (as produced by ``init_grid_param``).
        grid_dimensions: number of axes spanned by each plane.
        concat_features: concatenate scales if True, otherwise sum them.
        num_levels: number of leading scales to use; ``None`` means all.

    Returns:
        Feature tensor ``[N, C]`` where ``N`` is the number of points and ``C`` is
        the per-scale feature width (summed over scales when concatenating).
    """
    coo_combs = list(itertools.combinations(
        range(pts.shape[-1]), grid_dimensions)
    )
    if num_levels is None:
        num_levels = len(ms_grids)
    multi_scale_interp = [] if concat_features else 0.
    grid: nn.ParameterList
    for grid in ms_grids[:num_levels]:
        interp_space = 1.
        for ci, coo_comb in enumerate(coo_combs):
            # interpolate in plane
            feature_dim = grid[ci].shape[1]  # shape of grid[ci]: 1, out_dim, *reso
            interp_out_plane = (
                grid_sample_wrapper(grid[ci], pts[..., coo_comb])
                .view(-1, feature_dim)
            )
            # compute product over planes
            interp_space = interp_space * interp_out_plane

        # combine over scales
        if concat_features:
            multi_scale_interp.append(interp_space)
        else:
            multi_scale_interp = multi_scale_interp + interp_space

    if concat_features:
        multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)
    return multi_scale_interp


class HexPlaneField(nn.Module):
    """Multi-resolution, plane-factorized feature field over space and time.

    One set of planes (see ``init_grid_param``) is built per entry of
    ``multires``; only the first three resolutions are multiplied by that entry.
    Features of all scales are concatenated.

    Attributes:
        aabb: ``[2, 3]`` box used by ``normalize_aabb``; after ``__init__`` row 0
            is ``+bounds`` and row 1 is ``-bounds``.
        grid_config: single-element list holding ``planeconfig``.
        multiscale_res_multipliers: the ``multires`` sequence.
        concat_features: always True here; scale features are concatenated.
        grids: ``nn.ModuleList`` with one ``nn.ParameterList`` per scale.
        feat_dim: width of the feature vectors returned by ``forward``.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        """Build the planes for every scale.

        Args:
            bounds: half-extent of the initial cube ``[-bounds, bounds]^3``.
            planeconfig: dict read for ``"grid_dimensions"``,
                ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
                ``"resolution"``.
            multires: iterable of resolution multipliers, one per scale.
        """
        super().__init__()
        # Cast to float so the box gets the default floating dtype whether
        # ``bounds`` is an int, a float, or a NumPy scalar (an int64/float64 box
        # would otherwise leak into the coordinates fed to F.grid_sample).
        bound = float(bounds)
        aabb = torch.tensor([[bound, bound, bound], [-bound, -bound, -bound]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # initialize coordinate grid
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        print("feature_dim:", self.feat_dim)

    def set_aabb(self, xyz_max, xyz_min):
        """Replace the normalization box with ``[xyz_max, xyz_min]``.

        Args:
            xyz_max: corner stored as row 0 (tensor, array, or sequence).
            xyz_min: corner stored as row 1 (tensor, array, or sequence).
        """
        # Stack per-corner tensors (works for tensors, arrays and lists alike) and
        # keep the box on the device the module currently lives on.
        aabb = torch.stack([
            torch.as_tensor(xyz_max),
            torch.as_tensor(xyz_min),
        ]).to(device=self.aabb.device)
        # NOTE(review): __init__ creates the box with requires_grad=False, but here
        # it becomes trainable; confirm this asymmetry is intended.
        self.aabb = nn.Parameter(aabb, requires_grad=True)
        print("Voxel Plane: set aabb=", self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Compute the interpolated plane features at space-time points.

        Despite the name, this returns feature vectors, not densities.

        Args:
            pts: spatial points whose last dimension broadcasts against the
                3-column ``aabb``.
            timestamps: time values concatenated after the normalized spatial
                coordinates along the last dimension; required.

        Returns:
            ``[N, feat_dim]`` features, where ``N`` is the number of points.

        Raises:
            TypeError: if ``timestamps`` is None.
        """
        if timestamps is None:
            raise TypeError("HexPlaneField.get_density requires `timestamps`.")
        pts = normalize_aabb(pts, self.aabb)
        pts = torch.cat((pts, timestamps), dim=-1)
        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            # Empty input: return an empty tensor with the same width and dtype
            # as a non-empty result so downstream layers still see feat_dim.
            features = torch.zeros((0, self.feat_dim), dtype=features.dtype, device=features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Return ``get_density(pts, timestamps)``."""
        features = self.get_density(pts, timestamps)
        return features