from typing import *
import torch
import torch.nn as nn
from . import BACKEND, DEBUG

SparseTensorData = None  # Lazy import

__all__ = [
    "SparseTensor",
    "sparse_batch_broadcast",
    "sparse_batch_op",
    "sparse_cat",
    "sparse_unbind",
]


class SparseTensor:
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor. Column 0 is
      the batch index; the remaining columns are spatial coordinates.
    - shape (torch.Size): Shape of the sparse tensor, i.e.
      (batch_size, *feats.shape[1:]). Computed from feats/coords if omitted.
    - layout (List[slice]): Layout of the sparse tensor for each batch, i.e.
      the row range of feats/coords owned by each batch element. Computed
      from coords if omitted.
    - data (SparseTensorData): Sparse tensor data used for convolution

    Keyword-only options:
    - scale: Stored as the tensor's current scale (default (1, 1, 1)); it is
      used as the key under which spatial caches are registered.
    - spatial_cache (dict): Cache dictionary shared between derived tensors.

    NOTE:
    - Data corresponding to a same batch should be contiguous.
    - Coords should be in [0, 1023]
    """

    @overload
    def __init__(
        self,
        feats: torch.Tensor,
        coords: torch.Tensor,
        shape: Optional[torch.Size] = None,
        layout: Optional[List[slice]] = None,
        **kwargs,
    ): ...

    @overload
    def __init__(
        self,
        data,
        shape: Optional[torch.Size] = None,
        layout: Optional[List[slice]] = None,
        **kwargs,
    ): ...

    def __init__(self, *args, **kwargs):
        """
        Build a sparse tensor either from (feats, coords) or from backend data.

        The overload is selected from the first positional argument (a
        torch.Tensor means feats/coords) or, when there are no positional
        arguments, from the presence of the `data` keyword.
        """
        # Lazy import of sparse tensor backend
        global SparseTensorData
        if SparseTensorData is None:
            import importlib

            if BACKEND == "torchsparse":
                SparseTensorData = importlib.import_module("torchsparse").SparseTensor
            elif BACKEND == "spconv":
                SparseTensorData = importlib.import_module(
                    "spconv.pytorch"
                ).SparseConvTensor

        # These options belong to this wrapper, not to the backend
        # constructor, so take them out before kwargs is forwarded below.
        scale = kwargs.pop("scale", (1, 1, 1))
        spatial_cache = kwargs.pop("spatial_cache", {})

        if len(args) != 0:
            from_data = not isinstance(args[0], torch.Tensor)
        else:
            from_data = "data" in kwargs

        if not from_data:
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            # Keyword values take precedence over positional ones.
            feats = kwargs.pop("feats", feats)
            coords = kwargs.pop("coords", coords)
            shape = kwargs.pop("shape", shape)
            layout = kwargs.pop("layout", layout)

            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])

            if BACKEND == "torchsparse":
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == "spconv":
                # Spatial extent per axis, dropping the batch column.
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(
                    feats.reshape(feats.shape[0], -1),
                    coords,
                    spatial_shape,
                    shape[0],
                    **kwargs,
                )
                # The backend is built with flattened features; restore the
                # caller's (possibly multi-dimensional) feature tensor.
                self.data._features = feats
        else:
            data, shape, layout = args + (None,) * (3 - len(args))
            data = kwargs.pop("data", data)
            shape = kwargs.pop("shape", shape)
            layout = kwargs.pop("layout", layout)

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        self._scale = scale
        self._spatial_cache = spatial_cache

        if DEBUG:
            try:
                assert (
                    self.feats.shape[0] == self.coords.shape[0]
                ), f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(
                    self.feats, self.coords
                ), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(
                    self.coords, self.shape[0]
                ), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(
                        self.coords[self.layout[i], 0] == i
                    ), f"The data of batch {i} is not contiguous"
            except Exception:
                print("Debugging information:")
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise

    def __cal_shape(self, feats, coords):
        """Return torch.Size((max batch index + 1, *feats.shape[1:]))."""
        batch_size = coords[:, 0].max().item() + 1
        return torch.Size([batch_size, *feats.shape[1:]])

    def __cal_layout(self, coords, batch_size):
        """
        Return one slice per batch element covering its rows in coords.

        Relies on rows of the same batch being contiguous and ordered by
        batch index (see the class NOTE).
        """
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        # Convert once instead of calling .item() twice per batch element.
        ends = offset.tolist()
        starts = (offset - seq_len).tolist()
        return [slice(starts[i], ends[i]) for i in range(batch_size)]

    @property
    def shape(self) -> torch.Size:
        """Shape of the tensor: (batch_size, *feature dims)."""
        return self._shape

    def dim(self) -> int:
        """Number of dimensions of `shape`."""
        return len(self.shape)

    @property
    def layout(self) -> List[slice]:
        """Per-batch row ranges into feats/coords."""
        return self._layout

    @property
    def feats(self) -> torch.Tensor:
        """Feature tensor stored in the backend data."""
        if BACKEND == "torchsparse":
            return self.data.F
        elif BACKEND == "spconv":
            return self.data.features

    @feats.setter
    def feats(self, value: torch.Tensor):
        if BACKEND == "torchsparse":
            self.data.F = value
        elif BACKEND == "spconv":
            self.data.features = value

    @property
    def coords(self) -> torch.Tensor:
        """Coordinate tensor stored in the backend data."""
        if BACKEND == "torchsparse":
            return self.data.C
        elif BACKEND == "spconv":
            return self.data.indices

    @coords.setter
    def coords(self, value: torch.Tensor):
        if BACKEND == "torchsparse":
            self.data.C = value
        elif BACKEND == "spconv":
            self.data.indices = value

    @property
    def dtype(self):
        """dtype of the feature tensor."""
        return self.feats.dtype

    @property
    def device(self):
        """Device of the feature tensor."""
        return self.feats.device

    @overload
    def to(self, dtype: torch.dtype) -> "SparseTensor": ...

    @overload
    def to(
        self,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "SparseTensor": ...

    def to(self, *args, **kwargs) -> "SparseTensor":
        """
        Return a copy moved to `device` and/or cast to `dtype`.

        The dtype is applied to the features only; coordinates are only
        moved between devices.
        """
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if "dtype" in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs["dtype"]
        if "device" in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs["device"]

        new_feats = self.feats.to(device=device, dtype=dtype)
        new_coords = self.coords.to(device=device)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        """Return a copy whose features are cast with `Tensor.type(dtype)`."""
        return self.replace(self.feats.type(dtype))

    def cpu(self) -> "SparseTensor":
        """Return a copy with features and coordinates on the CPU."""
        return self.replace(self.feats.cpu(), self.coords.cpu())

    def cuda(self) -> "SparseTensor":
        """Return a copy with features and coordinates on the CUDA device."""
        return self.replace(self.feats.cuda(), self.coords.cuda())

    def half(self) -> "SparseTensor":
        """Return a copy with half-precision features."""
        return self.replace(self.feats.half())

    def float(self) -> "SparseTensor":
        """Return a copy with single-precision features."""
        return self.replace(self.feats.float())

    def detach(self) -> "SparseTensor":
        """Return a copy whose features and coordinates are detached."""
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def dense(self) -> torch.Tensor:
        """Return the backend's dense representation of the data."""
        if BACKEND in ("torchsparse", "spconv"):
            return self.data.dense()

    def reshape(self, *shape) -> "SparseTensor":
        """Reshape the per-point feature dims; the point dim is kept."""
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List["SparseTensor"]:
        """Unbind along `dim`; see `sparse_unbind`."""
        return sparse_unbind(self, dim)

    def replace(
        self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None
    ) -> "SparseTensor":
        """
        Return a new tensor with `feats` (and optionally `coords`) swapped in.

        Backend metadata, layout, scale and the spatial cache dictionary are
        carried over from this tensor; the cache dictionary is shared by
        reference, not copied.
        """
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == "torchsparse":
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif BACKEND == "spconv":
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict,
            )
            # Constructed from the old flattened features; install the new
            # feature tensor afterwards so its shape is kept as given.
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(
            new_data,
            shape=torch.Size(new_shape),
            layout=self.layout,
            scale=self._scale,
            spatial_cache=self._spatial_cache,
        )
        return new_tensor

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> "SparseTensor":
        """
        Create a tensor filled with `value` on every voxel of a box.

        Args:
            aabb: (x0, y0, z0, x1, y1, z1); both corners are inclusive.
            dim: (N, C) - batch size and number of feature channels.
            value: Fill value for the features.
            dtype: dtype of the features.
            device: Device for features and coordinates.
        """
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape(
            -1, 3
        )
        # Prefix every voxel with its batch index, batch by batch.
        coords = torch.cat(
            [
                torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
                coords.repeat(N, 1),
            ],
            dim=1,
        ).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: "SparseTensor") -> dict:
        """
        Return the union of this tensor's and `other`'s spatial caches.

        Neither input cache is modified. When both tensors already share the
        same cache dictionary it is returned as-is so sharing is preserved.
        On key conflicts inside a scale, `other`'s entry wins.
        """
        if other._spatial_cache is self._spatial_cache:
            return self._spatial_cache
        merged = dict(self._spatial_cache)
        for scale_key, theirs in other._spatial_cache.items():
            mine = merged.get(scale_key)
            if mine is None:
                merged[scale_key] = theirs
            elif mine is not theirs:
                # Build a fresh per-scale dict instead of updating `mine`,
                # which is still owned by this tensor's cache.
                merged[scale_key] = {**mine, **theirs}
        return merged

    def __neg__(self) -> "SparseTensor":
        return self.replace(-self.feats)

    def __elemwise__(
        self, other: Union[torch.Tensor, "SparseTensor"], op: Callable
    ) -> "SparseTensor":
        """
        Apply the binary `op` between this tensor's features and `other`.

        A dense tensor is first broadcast to `self.shape` and expanded along
        the batch layout when possible; otherwise it is passed to `op`
        unchanged. For a SparseTensor operand its features are used and the
        spatial caches of both operands are merged into the result.
        """
        # Remember the sparse operand: `other` is rebound to its features
        # below, and the cache merge still needs the original object.
        other_sparse = other if isinstance(other, SparseTensor) else None
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = sparse_batch_broadcast(self, other)
            except (RuntimeError, IndexError):
                # Best effort: not broadcastable per batch, so let `op`
                # apply torch's own broadcasting rules (or raise).
                pass
        if other_sparse is not None:
            other = other_sparse.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if other_sparse is not None:
            new_tensor._spatial_cache = self.__merge_sparse_cache(other_sparse)
        return new_tensor

    def __add__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.add)

    def __radd__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.add)

    def __sub__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.sub)

    def __rsub__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.mul)

    def __rmul__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.mul)

    def __truediv__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(
        self, other: Union[torch.Tensor, "SparseTensor", float]
    ) -> "SparseTensor":
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        """
        Select batch elements and return them as a new SparseTensor.

        `idx` may be an int, a slice, a 1D bool mask of length batch_size, or
        a 1D int32/int64 tensor of batch indices. Selected elements are
        renumbered 0..k-1 in the order given. The result is built from
        feats/coords only, so scale and spatial cache are not carried over.

        Raises:
            ValueError: If the index type or tensor dtype is unsupported.
        """
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (
                    self.shape[0],
                ), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
            # Plain ints index the layout list directly.
            idx = idx.tolist()
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        for new_idx, old_idx in enumerate(idx):
            rows = self.layout[old_idx]
            batch_coords = self.coords[rows].clone()
            batch_coords[:, 0] = new_idx
            coords.append(batch_coords)
            feats.append(self.feats[rows])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be any thing you want to cache.
        The registry and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.

        Returns the whole cache of the current scale when `key` is None,
        otherwise the entry for `key` (None if it is not registered).
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)


def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor to every point of a sparse tensor.

    Args:
        input (SparseTensor): Sparse tensor that defines the batch layout.
        other (torch.Tensor): Tensor indexed by batch along dim 0; `other[k]`
            is assigned to all rows belonging to batch `k`.

    Returns:
        torch.Tensor: Tensor with the same shape, dtype and device as
        `input.feats`.
    """
    broadcasted = torch.zeros_like(input.feats)
    for k in range(input.shape[0]):
        broadcasted[input.layout[k]] = other[k]
    return broadcasted


def sparse_batch_op(
    input: SparseTensor, other: torch.Tensor, op: Callable = torch.add
) -> SparseTensor:
    """
    Broadcast a per-batch tensor to a sparse tensor along the batch dimension
    then perform an operation.

    Args:
        input (SparseTensor): Sparse tensor to operate on.
        other (torch.Tensor): Tensor indexed by batch along dim 0.
        op (callable): Operation applied as `op(input.feats, broadcasted)`.
            Defaults to torch.add.
    """
    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))


def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): 0 concatenates along the batch dimension (batch indices
            of later inputs are shifted); any other value concatenates the
            features along that dimension and reuses the coordinates of
            `inputs[0]`.
    """
    if dim == 0:
        start = 0
        coords = []
        for input in inputs:
            shifted = input.coords.clone()
            shifted[:, 0] += start
            coords.append(shifted)
            start += input.shape[0]
        coords = torch.cat(coords, dim=0)
        feats = torch.cat([input.feats for input in inputs], dim=0)
        output = SparseTensor(
            coords=coords,
            feats=feats,
        )
    else:
        feats = torch.cat([input.feats for input in inputs], dim=dim)
        output = inputs[0].replace(feats)

    return output


def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind. 0 splits into one tensor per batch
            element; any other value unbinds the features along that
            dimension.
    """
    if dim == 0:
        return [input[i] for i in range(input.shape[0])]
    else:
        feats = input.feats.unbind(dim)
        return [input.replace(f) for f in feats]