dmitriitochilkin's picture
add dependencies
ff49a48
from typing import Callable, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torchmcubes import marching_cubes
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface helpers.

    The base class only declares the coordinate range and the
    ``grid_vertices`` accessor; accessing ``grid_vertices`` here raises
    ``NotImplementedError``, so concrete helpers must override it.
    """

    # (lower, upper) bounds passed to ``torch.linspace`` when a helper
    # builds its sampling grid.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Grid sample positions; not provided by the base class.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError()
class MarchingCubeHelper(IsosurfaceHelper):
    """Isosurface helper that meshes a cubic grid of level values.

    The grid has ``resolution`` samples along each of the three axes, spanning
    ``points_range`` on every axis. ``forward`` hands the level values to
    ``mc_func`` (``marching_cubes`` from ``torchmcubes`` by default).
    """

    def __init__(self, resolution: int) -> None:
        """Store the grid resolution and the marching-cubes backend.

        Args:
            resolution: Number of grid samples along each axis.
        """
        super().__init__()
        self.resolution = resolution
        self.mc_func: Callable = marching_cubes
        # Built lazily on first access to ``grid_vertices``.
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        """Return the grid sample positions, building them on first access.

        Returns:
            Tensor of shape ``(resolution ** 3, 3)`` holding ``(x, y, z)``
            coordinates. Rows follow ``"ij"`` meshgrid order, so the last
            coordinate varies fastest. The result is cached on the instance.
        """
        if self._grid_vertices is None:
            # No device is requested here, so the grid is created on the
            # default device; the intent is to keep the vertices on CPU so
            # that very large resolutions can be supported.
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack([x, y, z], dim=-1).reshape(-1, 3)
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Extract the zero-level surface from ``level``.

        Args:
            level: Level values with exactly ``resolution ** 3`` elements,
                laid out in the same order as ``grid_vertices``.

        Returns:
            ``(v_pos, t_pos_idx)`` as produced by ``mc_func``, with the vertex
            positions rescaled from grid-index units into ``points_range``.
        """
        # ``reshape`` rather than ``view`` so non-contiguous inputs (slices,
        # transposes) are accepted instead of raising.
        # NOTE(review): the sign flip presumably adapts the level convention
        # to what the marching-cubes backend expects — confirm.
        grid = -level.reshape(self.resolution, self.resolution, self.resolution)
        v_pos, t_pos_idx = self.mc_func(grid.detach(), 0.0)
        # Reverse the coordinate columns.
        # NOTE(review): presumably the backend returns axes in reverse order
        # relative to ``grid_vertices`` — confirm against torchmcubes.
        v_pos = v_pos[..., [2, 1, 0]]
        # Grid-index units [0, resolution - 1] -> unit cube [0, 1].
        v_pos = v_pos / (self.resolution - 1.0)
        # Map into the same range that ``grid_vertices`` samples. For the
        # default (0, 1) this leaves the values unchanged.
        lower, upper = self.points_range
        v_pos = lower + v_pos * (upper - lower)
        return v_pos, t_pos_idx