|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import craftsman |
|
from craftsman.utils.typing import * |
|
|
|
|
|
def dot(x, y):
    """Dot product of ``x`` and ``y`` over their last axis.

    The reduced axis is kept with size 1, so the result broadcasts back
    against the inputs (e.g. ``(N, 3) x (N, 3) -> (N, 1)``).
    """
    product = x * y
    return product.sum(dim=-1, keepdim=True)
|
|
|
|
|
class Mesh:
    """Triangle mesh with lazily computed per-vertex attributes.

    Stores vertex positions ``v_pos`` and triangle vertex indices
    ``t_pos_idx``. Vertex normals, tangents, UV coordinates (via xatlas) and
    unique edges are computed on first access and cached. Extra keyword
    arguments given to the constructor are stored in ``self.extras``.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily populated caches, filled by the corresponding properties.
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._t_tex_idx: Optional[Integer[Tensor, "Nf 3"]] = None
        self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v):
        """Store an arbitrary extra attribute ``v`` under key ``k``."""
        self.extras[k] = v

    def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]):
        """Return a mesh without small disconnected components.

        Args:
            outlier_n_faces_threshold: components with fewer faces than this
                are dropped. An ``int`` is an absolute face count. A ``float``
                is a fraction of the face count of the largest component.

        Returns:
            A new ``Mesh`` whose ``extras`` are a shallow copy of this mesh's
            extras. Returns ``self`` unchanged in three cases: the mesh is
            differentiable, it has no components, or every component would be
            removed.
        """
        if self.requires_grad:
            craftsman.debug("Mesh is differentiable, not removing outliers")
            return self

        # trimesh is only needed here, so it is imported lazily.
        import trimesh

        mesh = trimesh.Trimesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
        )

        # Split into connected components, whether watertight or not.
        components = mesh.split(only_watertight=False)
        craftsman.debug(
            "Mesh has {} components, with faces: {}".format(
                len(components), [c.faces.shape[0] for c in components]
            )
        )
        # len() is used because `components` may be a numpy object array,
        # whose truthiness is ambiguous.
        if len(components) == 0:
            craftsman.warn("Mesh has no components, not removing outliers")
            return self

        n_faces_threshold: int
        if isinstance(outlier_n_faces_threshold, float):
            # Relative threshold: fraction of the largest component's faces.
            n_faces_threshold = int(
                max(c.faces.shape[0] for c in components) * outlier_n_faces_threshold
            )
        else:
            n_faces_threshold = outlier_n_faces_threshold

        craftsman.debug(
            "Removing components with less than {} faces".format(n_faces_threshold)
        )
        components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
        craftsman.debug(
            "Mesh has {} components after removing outliers, with faces: {}".format(
                len(components), [c.faces.shape[0] for c in components]
            )
        )
        if len(components) == 0:
            # Concatenating nothing would not yield a mesh; keep the input.
            craftsman.warn(
                f"All components have less than {n_faces_threshold} faces, "
                "not removing outliers"
            )
            return self

        mesh = trimesh.util.concatenate(components)

        # Match the dtype/device of the original tensors.
        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
        clean_mesh = Mesh(v_pos, t_pos_idx)

        if len(self.extras) > 0:
            # Shallow copy so later add_extra calls on either mesh stay separate.
            clean_mesh.extras = dict(self.extras)
            craftsman.debug(
                f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
            )
        return clean_mesh

    @property
    def requires_grad(self):
        """Whether the vertex positions require gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex unit normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex unit tangents, computed on first access."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """UV coordinates; triggers xatlas unwrapping if not yet available."""
        if self._v_tex is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._v_tex

    @property
    def t_tex_idx(self):
        """Per-face indices into ``v_tex``; triggers unwrapping if needed."""
        if self._t_tex_idx is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        """Per-vertex colors set via ``set_vertex_color`` (``None`` if unset)."""
        return self._v_rgb

    @property
    def edges(self):
        """Unique undirected edges, computed on first access."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Compute area-weighted, normalized per-vertex normals.

        Each face's unnormalized cross product is accumulated onto its three
        vertices. Vertices whose accumulated normal is (near) zero fall back
        to ``(0, 0, 1)``.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        # dim=-1 is explicit: without it torch.cross uses the first axis of
        # size 3, which is the face axis when the mesh has exactly 3 faces.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Replace degenerate normals before normalizing.
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from positions and UV coordinates.

        Accessing ``v_tex``/``t_tex_idx``/``v_nrm`` computes them if they are
        not cached. Per-face tangents are averaged over the faces incident to
        each vertex, normalized, and orthogonalized against ``v_nrm``.
        """
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_tex_idx[:, i]]
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Triangle edge vectors in UV space and in object space.
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Keep |denom| >= 1e-6 while preserving its sign (0 maps to -1e-6).
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Accumulate face tangents and incidence counts onto the vertices.
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)
            tansum.scatter_add_(0, idx, torch.ones_like(tang))
        # Vertices referenced by no face have a count of 0. Clamping to 1
        # leaves them at zero instead of producing 0/0 = NaN.
        tangents = tangents / torch.clamp(tansum, min=1.0)

        tangents = F.normalize(tangents, dim=1)
        # Gram-Schmidt: remove the normal component, then renormalize.
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def _unwrap_uv(
        self,
        xatlas_chart_options: Optional[dict] = None,
        xatlas_pack_options: Optional[dict] = None,
    ):
        """Unwrap UVs with xatlas.

        The chart option ``max_cost=2.0`` and the pack option
        ``resolution=4096`` are applied first as defaults. Entries in
        ``xatlas_chart_options`` / ``xatlas_pack_options`` are then applied
        with ``setattr``, so callers can override those defaults.

        Returns:
            ``(uvs, indices)``: float UV coordinates and long per-face indices
            into them, both on ``v_pos``'s device.
        """
        craftsman.info("Using xatlas to perform UV unwrapping, may take a while ...")

        import xatlas

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        # Defaults first so that user-supplied options take precedence.
        setattr(co, 'max_cost', 2.0)
        setattr(po, 'resolution', 4096)
        for k, v in (xatlas_chart_options or {}).items():
            setattr(co, k, v)
        for k, v in (xatlas_pack_options or {}).items():
            setattr(po, k, v)

        atlas.generate(co, po, verbose=True)
        # The vertex mapping returned first is not needed here.
        _, indices, uvs = atlas.get_mesh(0)
        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        indices = (
            torch.from_numpy(
                indices.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices

    def unwrap_uv(
        self,
        xatlas_chart_options: Optional[dict] = None,
        xatlas_pack_options: Optional[dict] = None,
    ):
        """(Re)compute and cache UVs with xatlas using the given options."""
        self._v_tex, self._t_tex_idx = self._unwrap_uv(
            xatlas_chart_options, xatlas_pack_options
        )

    def set_vertex_color(self, v_rgb):
        """Set per-vertex colors; the first dimension must match ``v_pos``."""
        assert v_rgb.shape[0] == self.v_pos.shape[0]
        self._v_rgb = v_rgb

    def _compute_edges(self):
        """Return the unique undirected edges as sorted ``(i, j)`` index pairs."""
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort each pair so (a, b) and (b, a) deduplicate to the same row.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges

    def normal_consistency(self) -> Float[Tensor, ""]:
        """Mean of ``1 - cos`` between the normals at the two ends of each edge."""
        edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
        nc = (
            1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
        ).mean()
        return nc

    def _laplacian_uniform(self):
        """Build the uniform graph Laplacian ``D - A`` as a sparse ``(V, V)`` tensor.

        Off-diagonal entries are -1 for each neighbor pair. The diagonal holds
        the vertex degree, obtained by coalescing (summing) duplicate entries.
        """
        verts, faces = self.v_pos, self.t_pos_idx
        V = verts.shape[0]

        # Directed half-edges of every face, in both directions, deduplicated.
        ii = faces[:, [1, 2, 0]].flatten()
        jj = faces[:, [2, 0, 1]].flatten()
        adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
            dim=1
        )
        adj_values = torch.ones(adj.shape[1]).to(verts)

        # One +1 diagonal entry per adjacency; coalesce sums them into degrees.
        diag_idx = adj[0]
        idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
        values = torch.cat((-adj_values, adj_values))

        return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()

    def laplacian(self) -> Float[Tensor, ""]:
        """Uniform Laplacian smoothness loss: mean over vertices of ``||L v||``.

        The Laplacian itself is built without gradients; the loss remains
        differentiable with respect to ``v_pos``.
        """
        with torch.no_grad():
            L = self._laplacian_uniform()
        loss = L.mm(self.v_pos)
        loss = loss.norm(dim=1)
        loss = loss.mean()
        return loss
|
|
|
class IsosurfaceHelper(nn.Module):
    """Abstract base for modules that extract a ``Mesh`` from a sampled field.

    Concrete helpers expose the positions at which the field is sampled via
    the ``grid_vertices`` property.
    """

    # Per-axis (low, high) bounds of the sampling domain.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Sample positions of the field; subclasses must override this."""
        raise NotImplementedError
|
|
|
|
|
class MarchingCubeCPUHelper(IsosurfaceHelper):
    """Isosurface extraction on the CPU using PyMCubes' marching cubes.

    The field is sampled on a dense ``resolution ** 3`` grid spanning
    ``points_range`` along each axis (see ``grid_vertices``). Extracted mesh
    vertices are returned in that same coordinate range.
    """

    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        import mcubes

        self.mc_func: Callable = mcubes.marching_cubes
        self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
        # Empty, non-persistent buffer; only used to track the module's device.
        self._dummy: Float[Tensor, "..."]
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self) -> Float[Tensor, "N3 3"]:
        """Grid sample positions of shape ``(resolution ** 3, 3)``.

        Built once with ``meshgrid(..., indexing="ij")`` and cached. The
        tensor is created on the CPU and is not moved to the module device.
        """
        if self._grid_vertices is None:
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            self._grid_vertices = torch.stack(
                [x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1
            )
        return self._grid_vertices

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the zero level set of ``level`` sampled at ``grid_vertices``.

        ``deformation`` is not supported; a warning is emitted and it is
        ignored. The extraction runs on a detached CPU copy, so the returned
        mesh carries no gradients.
        """
        if deformation is not None:
            craftsman.warn(
                f"{self.__class__.__name__} does not support deformation. Ignoring."
            )
        # The level is negated before extraction, as in the original code.
        # NOTE(review): presumably this matches mcubes' inside/outside sign
        # convention to the field's convention; confirm with callers.
        level = -level.view(self.resolution, self.resolution, self.resolution)
        v_pos, t_pos_idx = self.mc_func(level.detach().cpu().numpy(), 0.0)

        device = self._dummy.device
        v_pos = torch.from_numpy(v_pos).float().to(device)
        t_pos_idx = torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(device)

        # mcubes reports vertices in grid-index units [0, resolution - 1].
        # Map them into points_range so they agree with grid_vertices; for the
        # default (0, 1) this equals plain division by (resolution - 1).
        lo, hi = self.points_range
        v_pos = v_pos / (self.resolution - 1.0) * (hi - lo) + lo
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
|
|
|
|
|
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Differentiable isosurface extraction with marching tetrahedra.

    The tetrahedral grid (``vertices`` and ``indices`` arrays) is loaded from
    the ``.npz`` file at ``tets_path``. Vertex positions are interpolated
    outside ``torch.no_grad``, so gradients reach both the level values and
    the (optionally deformed) grid vertices.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # For each of the 16 occupancy codes of a tet's 4 vertices: up to two
        # triangles, given as slots into the tet's 6 edges (-1 = unused).
        self.triangle_table: Integer[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each occupancy code.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The 6 edges of a tet as flattened local vertex pairs.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # Use a context manager so the .npz archive's file handle is closed
        # once both arrays have been read into memory.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            indices = torch.from_numpy(tets["indices"]).long()
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", indices, persistent=False)

        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with tanh and scale them to at most one cell width."""
        return (
            (self.points_range[1] - self.points_range[0])
            / (self.resolution)
            * torch.tanh(grid_vertex_offsets)
        )

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed tet-grid vertex positions loaded from ``tets_path``."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique, endpoint-sorted edges of all tetrahedra (computed once)."""
        if self._all_edges is None:
            # base_tet_edges is a buffer, so it lives on the same device as indices.
            _all_edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Reorder each ``(a, b)`` edge so the smaller index comes first.

        Returns a tensor of shape ``(E, 1, 2)``; callers reshape it as needed.
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra on one tet grid.

        Args:
            pos_nx3: grid vertex positions.
            sdf_n: level value per grid vertex (``forward`` passes ``(N, 1)``).
                Vertices with value > 0 count as occupied.
            tet_fx4: vertex indices of each tetrahedron.

        Returns:
            ``(verts, faces)``: interpolated surface vertices and triangle
            indices into them.
        """
        # Topology (which edges are crossed, how they connect) needs no grad.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tets with mixed occupancy contain part of the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # Six edges per valid tet, deduplicated across tets.
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # An edge is crossed when exactly one endpoint is occupied. Each
            # crossed edge gets a new output-vertex id; others map to -1.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]

            interp_v = unique_edges[mask_edges]

        # Linear zero-crossing along each crossed edge (kept differentiable).
        # With endpoint values s0, s1 the weights are (-s1, s0) / (s0 - s1).
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy code per valid tet indexes the lookup tables.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Gather the crossed-edge vertex ids for tets emitting 1 or 2 triangles.
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the level set, optionally on a deformed grid.

        The returned ``Mesh`` carries the grid vertices, tet edges, level and
        deformation as extras.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
|
|