wyysf's picture
i
e814557
raw
history blame
No virus
6.9 kB
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import craftsman
from .utils import (
Mesh,
IsosurfaceHelper,
MarchingCubeCPUHelper,
MarchingTetrahedraHelper,
)
from craftsman.utils.base import BaseModule
from craftsman.utils.ops import chunk_batch, scale_tensor
from craftsman.utils.typing import *
class BaseGeometry(BaseModule):
    """Root class for geometry modules, built on top of ``BaseModule``.

    Provides a ``create_from`` hook that always raises in this base class and
    an ``export`` method that returns an empty dict.
    """

    @dataclass
    class Config(BaseModule.Config):
        """Configuration for ``BaseGeometry``; declares no fields of its own."""

    cfg: Config

    @staticmethod
    def create_from(
        other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs
    ) -> "BaseGeometry":
        """Build a geometry from another geometry instance.

        The base implementation performs no conversion and unconditionally
        raises; ``cfg`` and ``kwargs`` are accepted but not used here.

        Raises:
            TypeError: Always. The message names ``BaseGeometry`` and the
                class of ``other``.
        """
        message = (
            f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}"
        )
        raise TypeError(message)

    def export(self, *args, **kwargs):
        """Return exported data; the base implementation returns an empty dict."""
        return {}
class BaseImplicitGeometry(BaseGeometry):
    """Base class for geometries described by an implicit scalar field.

    Subclasses implement ``forward``, ``forward_field`` and ``forward_level``.
    This class evaluates the field on the grid vertices of an isosurface
    helper and turns the result into a ``Mesh`` via ``isosurface()``.
    """

    @dataclass
    class Config(BaseGeometry.Config):
        # Half side length of the axis-aligned box [-radius, radius]^3 stored
        # in the ``bbox`` buffer.
        radius: float = 1.0
        # When False, ``isosurface()`` raises NotImplementedError.
        isosurface: bool = True
        # Helper selection: "mt" (MarchingTetrahedraHelper) or
        # "mc-cpu" (MarchingCubeCPUHelper).
        isosurface_method: str = "mt"
        # Resolution passed to the helper; for "mt" it also selects the
        # ``load/tets/{resolution}_tets.npz`` file.
        isosurface_resolution: int = 128
        # A number used directly as the threshold, or the string "auto" to
        # derive it from the mean of the field values above 1e-5.
        isosurface_threshold: Union[float, str] = 0.0
        # Chunk size forwarded to ``chunk_batch`` when evaluating the field.
        isosurface_chunk: int = 0
        # When True, a coarse pass over the full bbox is used to tighten the
        # bounding box before the final extraction.
        isosurface_coarse_to_fine: bool = True
        # Not read anywhere in this class.
        isosurface_deformable_grid: bool = False
        # When True, ``Mesh.remove_outlier`` is applied to the extracted mesh.
        isosurface_remove_outliers: bool = True
        # Argument forwarded to ``Mesh.remove_outlier``.
        isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01

    cfg: Config

    def configure(self) -> None:
        """Register the ``bbox`` buffer and reset the isosurface state."""
        self.bbox: Float[Tensor, "2 3"]
        radius = self.cfg.radius
        # Row 0 is the minimum corner, row 1 the maximum corner.
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-radius, -radius, -radius],
                    [radius, radius, radius],
                ],
                dtype=torch.float32,
            ),
        )
        # Created lazily by ``_initilize_isosurface_helper``.
        self.isosurface_helper: Optional[IsosurfaceHelper] = None
        self.unbounded: bool = False

    def _initilize_isosurface_helper(self):
        """Lazily create the helper selected by ``cfg.isosurface_method``.

        Does nothing when isosurface extraction is disabled in the config or
        when a helper has already been created.

        Raises:
            AttributeError: If ``cfg.isosurface_method`` is neither
                ``"mc-cpu"`` nor ``"mt"``.
        """
        if not self.cfg.isosurface or self.isosurface_helper is not None:
            return

        method = self.cfg.isosurface_method
        resolution = self.cfg.isosurface_resolution
        if method == "mc-cpu":
            self.isosurface_helper = MarchingCubeCPUHelper(resolution).to(self.device)
        elif method == "mt":
            self.isosurface_helper = MarchingTetrahedraHelper(
                resolution,
                f"load/tets/{resolution}_tets.npz",
            ).to(self.device)
        else:
            # The f-prefix is required so the offending value is reported
            # instead of the literal placeholder text.
            raise AttributeError(f"Unknown isosurface method {method}")

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Evaluate the geometry at ``points``; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        """Return the implicit field value and an optional deformation.

        The field value could be a density or a signed distance. A deformation
        field is also returned if the grid vertices can be optimized,
        otherwise ``None``. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        """Map ``field`` to a level function whose zero level set is the surface.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh:
        """Extract a mesh from the implicit field inside ``bbox``.

        Args:
            bbox: Minimum and maximum corners of the region to extract in.
            fine_stage: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            The extracted mesh with vertices rescaled into ``bbox`` and the
            box attached as the ``"bbox"`` extra.

        Raises:
            TypeError: If ``cfg.isosurface_threshold`` is neither a real
                number nor the string ``"auto"``.
        """
        assert self.isosurface_helper is not None
        helper = self.isosurface_helper

        def batch_func(x):
            # Map the helper's grid coordinates (helper.points_range) into bbox
            # before querying the field.
            field, deformation = self.forward_field(
                scale_tensor(x.to(bbox.device), helper.points_range, bbox),
            )
            # Move results back to the device of the input (could be CPU).
            field = field.to(x.device)
            if deformation is not None:
                deformation = deformation.to(x.device)
            return field, deformation

        field, deformation = chunk_batch(
            batch_func,
            self.cfg.isosurface_chunk,
            helper.grid_vertices,
        )

        threshold: float
        cfg_threshold = self.cfg.isosurface_threshold
        # Accept integers as well as floats (a config value of ``0`` is a valid
        # threshold); bool is excluded because it is a subclass of int.
        if isinstance(cfg_threshold, (int, float)) and not isinstance(
            cfg_threshold, bool
        ):
            threshold = float(cfg_threshold)
        elif cfg_threshold == "auto":
            eps = 1.0e-5
            # Mean of the field values strictly above eps.
            threshold = field[field > eps].mean().item()
            craftsman.info(
                f"Automatically determined isosurface threshold: {threshold}"
            )
        else:
            raise TypeError(
                f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}"
            )

        level = self.forward_level(field, threshold)
        mesh: Mesh = helper(level, deformation=deformation)
        # Rescale vertices from the helper's points_range into bbox.
        mesh.v_pos = scale_tensor(mesh.v_pos, helper.points_range, bbox)
        mesh.add_extra("bbox", bbox)

        if self.cfg.isosurface_remove_outliers:
            # Remove outlier components with a small number of faces.
            # NOTE(review): the original comment says this is only meant for
            # non-differentiable meshes; nothing here enforces that - confirm.
            mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)

        return mesh

    def isosurface(self) -> Mesh:
        """Extract the surface mesh of the implicit field.

        With ``cfg.isosurface_coarse_to_fine`` enabled, a first pass over the
        full ``bbox`` (run under ``torch.no_grad``) yields the vertex bounds,
        which are padded by 10% of their extent, clamped to ``bbox``, and used
        as the region for the final pass.

        Raises:
            NotImplementedError: If ``cfg.isosurface`` is False.
        """
        if not self.cfg.isosurface:
            raise NotImplementedError(
                "Isosurface is not enabled in the current configuration"
            )
        self._initilize_isosurface_helper()

        if not self.cfg.isosurface_coarse_to_fine:
            return self._isosurface(self.bbox)

        craftsman.debug("First run isosurface to get a tight bounding box ...")
        with torch.no_grad():
            mesh_coarse = self._isosurface(self.bbox)
        vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0)
        # Pad by 10% of the extent on each side, clamped to the full bbox.
        vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0])
        vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1])
        craftsman.debug("Run isosurface again with the tight bounding box ...")
        return self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True)
class BaseExplicitGeometry(BaseGeometry):
    """Base class for explicit geometries; it only sets up a bounding box."""

    @dataclass
    class Config(BaseGeometry.Config):
        # Half side length of the axis-aligned box [-radius, radius]^3.
        radius: float = 1.0

    cfg: Config

    def configure(self) -> None:
        """Register the ``bbox`` buffer derived from ``cfg.radius``."""
        self.bbox: Float[Tensor, "2 3"]
        half_extent = self.cfg.radius
        # Row 0 is the minimum corner, row 1 the maximum corner.
        corners = [[-half_extent] * 3, [half_extent] * 3]
        self.register_buffer("bbox", torch.as_tensor(corners, dtype=torch.float32))