Spaces:
Running
Running
from typing import Callable, Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchmcubes import marching_cubes | |
class IsosurfaceHelper(nn.Module):
    """Base class for modules that extract an isosurface from a sampled grid.

    Subclasses are expected to override :meth:`grid_vertices`; this base
    implementation only raises ``NotImplementedError``.
    """

    # (lower, upper) coordinate bounds of the sampling grid along every axis.
    points_range: Tuple[float, float] = (0, 1)

    def grid_vertices(self) -> torch.FloatTensor:
        """Return the coordinates of the grid sample points.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement grid_vertices()"
        )
class MarchingCubeHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a dense cubic scalar grid via marching cubes.

    The grid has ``resolution`` samples along each axis, laid out over
    ``points_range``. The extraction itself is delegated to ``self.mc_func``.
    """

    def __init__(self, resolution: int) -> None:
        """Store the grid resolution and select the marching-cubes backend.

        Args:
            resolution: Number of grid samples along each of the three axes.
        """
        super().__init__()
        self.resolution = resolution
        self.mc_func: Callable = marching_cubes
        # Built lazily on the first grid_vertices() call, then reused.
        self._grid_vertices: Optional[torch.FloatTensor] = None

    def grid_vertices(self) -> torch.FloatTensor:
        """Return the cached ``(resolution ** 3, 3)`` grid sample coordinates.

        Rows are ordered so that the first coordinate varies slowest and the
        last varies fastest (``indexing="ij"`` followed by a row-major
        flatten). The tensor is created once and kept as a plain attribute,
        not as a registered buffer.
        """
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Run marching cubes on ``level`` and return the resulting mesh.

        Args:
            level: Scalar field with exactly ``resolution ** 3`` elements. It
                is negated before the surface is extracted at iso-value 0.0.
                Non-contiguous tensors are accepted.

        Returns:
            A ``(v_pos, t_pos_idx)`` pair as produced by ``self.mc_func``,
            with the last axis of ``v_pos`` reversed and then divided by
            ``resolution - 1``. The field is detached before extraction, so
            no gradient flows back to ``level``.
        """
        side = self.resolution
        # reshape() rather than view(): view() raises on non-contiguous input,
        # reshape() copies only when it has to.
        volume = -level.reshape(side, side, side)
        v_pos, t_pos_idx = self.mc_func(volume.detach(), 0.0)
        # Reverse the coordinate columns; presumably the backend emits them in
        # the opposite axis order to the grid layout - TODO confirm.
        v_pos = v_pos[..., [2, 1, 0]]
        # Assumes the backend returns vertices in grid-index units
        # (0 .. resolution - 1), so this maps them to [0, 1].
        # NOTE(review): this scaling ignores points_range; it matches
        # grid_vertices() only for the default range (0, 1).
        v_pos = v_pos / (self.resolution - 1.0)
        return v_pos, t_pos_idx