|
import torch |
|
import torch.nn.functional as tfunc |
|
import torch_scatter |
|
|
|
def prepend_dummies(
        vertices:torch.Tensor,
        faces:torch.Tensor,
    )->tuple[torch.Tensor,torch.Tensor]:
    """Prepend dummy elements to vertices and faces to enable "masked" scatter operations.

    A NaN-filled row becomes vertex 0 and an all-zero row becomes face 0.
    Every original face index is shifted up by one so that it still refers to
    the same vertex after the insertion.

    Args:
        vertices: (V, D) tensor of vertex positions.
        faces: (F, 3) integer tensor of vertex indices.

    Returns:
        Tuple (vertices, faces) of shapes (V+1, D) and (F+1, 3).
    """
    _,D = vertices.shape

    # Build the NaN row in the dtype of `vertices`. A default-dtype (float32) row
    # would make torch.concat promote float16/bfloat16 input to float32.
    # Non-floating input cannot hold NaN, so it keeps the default dtype as before.
    dummy_dtype = vertices.dtype if vertices.is_floating_point() else None
    dummy_vertex = torch.full((1,D),fill_value=torch.nan,dtype=dummy_dtype,device=vertices.device)
    vertices = torch.concat((dummy_vertex,vertices),dim=0)

    # face 0 references only the dummy vertex; real faces move up by one index
    dummy_face = torch.zeros((1,3),dtype=torch.long,device=faces.device)
    faces = torch.concat((dummy_face,faces+1),dim=0)
    return vertices,faces
|
|
|
def remove_dummies(
        vertices:torch.Tensor,
        faces:torch.Tensor,
    )->tuple[torch.Tensor,torch.Tensor]:
    """Strip the dummy elements added with prepend_dummies().

    Drops row 0 of both tensors and shifts the remaining face indices down by
    one, so they index into the returned vertices.
    """
    real_vertices = vertices[1:]
    real_faces = faces[1:] - 1
    return real_vertices,real_faces
|
|
|
|
|
def calc_edges(
        faces: torch.Tensor,
        with_edge_to_face: bool = False
    ) -> tuple[torch.Tensor, ...]:
    """
    Derive the unique undirected edges of a triangle list.

    returns tuple of
    - edges E,2 long, 0 for unused, lower vertex index first
    - face_to_edge F,3 long
    - (optional) edge_to_face shape=E,[left,right],[face,side]

    o-<-----e1     e0,e1...edge, e0<e1
    |      /A      L,R....left and right face
    |  L /  |      both triangles ordered counter clockwise
    |  / R  |      normals pointing out of screen
    V/      |
    e0---->-o
    """
    device = faces.device
    num_faces = faces.shape[0]

    # side s of a face runs from corner s to corner (s+1)%3
    next_corner = faces.roll(-1,1)
    directed = torch.stack((faces,next_corner),dim=-1).reshape(num_faces*3,2)
    # lower vertex index first, so both directions of an edge compare equal
    undirected = directed.sort(dim=-1).values

    edges,side_to_edge = torch.unique(input=undirected,sorted=True,return_inverse=True,dim=0)
    face_to_edge = side_to_edge.reshape(num_faces,3)
    if not with_edge_to_face:
        return edges, face_to_edge

    num_edges = edges.shape[0]
    # a side stored from the higher to the lower vertex index goes into the
    # second ("right") slot of its edge, any other side into the first ("left")
    reversed_side = directed[:,0]!=undirected[:,0]
    slot = (2*side_to_edge+reversed_side)[:,None].expand(num_faces*3,2)
    # row f*3+s holds the pair (f,s), matching the row order of `directed`
    face_side = torch.cartesian_prod(torch.arange(0,num_faces,device=device),torch.arange(0,3,device=device))

    edge_to_face = torch.zeros((num_edges,2,2),dtype=torch.long,device=device)
    # the reshape of this freshly allocated tensor is a view, so the scatter fills edge_to_face
    edge_to_face.reshape(2*num_edges,2).scatter_(dim=0,index=slot,src=face_side)
    edge_to_face[0] = 0
    return edges, face_to_edge, edge_to_face
|
|
|
def calc_edge_length(
        vertices:torch.Tensor,
        edges:torch.Tensor,
    )->torch.Tensor:
    """Return the Euclidean length of every edge.

    `edges` holds two vertex indices per row; the result has one entry per row.
    """
    start,end = vertices[edges].unbind(dim=1)
    return torch.norm(start-end,p=2,dim=-1)
|
|
|
def calc_face_normals(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        normalize:bool=False,
    )->torch.Tensor:
    """
    Compute one normal per face as cross(c1-c0, c2-c0).

    Without `normalize` the length of the result is twice the face area.

         n
         |
         c0     corners ordered counterclockwise when
        / \     looking onto surface (in neg normal direction)
      c1---c2
    """
    c0,c1,c2 = vertices[faces].unbind(dim=1)
    normals = torch.cross(c1-c0,c2-c0, dim=1)
    if not normalize:
        return normals
    return tfunc.normalize(normals, eps=1e-6, dim=1)
|
|
|
def calc_vertex_normals(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        face_normals:torch.Tensor=None,
    )->torch.Tensor:
    """Return unit vertex normals, accumulated from the normals of adjacent faces.

    Each face adds its normal to all three of its corner vertices; the sums are
    then normalized. When `face_normals` is omitted it is computed with
    calc_face_normals() (unnormalized, so larger faces weigh more).
    """
    if face_normals is None:
        face_normals = calc_face_normals(vertices,faces)

    num_vertices = vertices.shape[0]
    num_faces = faces.shape[0]

    # one accumulator slot per (vertex, corner position), merged after the scatter
    target = faces[:,:,None].expand(num_faces,3,3)
    contribution = face_normals[:,None,:].expand(num_faces,3,3)
    per_corner = torch.zeros((num_vertices,3,3),dtype=vertices.dtype,device=vertices.device)
    per_corner.scatter_add_(dim=0,index=target,src=contribution)

    summed = per_corner.sum(dim=1)
    return tfunc.normalize(summed, eps=1e-6, dim=1)
|
|
|
def calc_face_ref_normals(
        faces:torch.Tensor,
        vertex_normals:torch.Tensor,
        normalize:bool=False,
    )->torch.Tensor:
    """Calculate reference normals for face flip detection.

    The reference normal of a face is the sum of the vertex normals at its
    corners, optionally scaled to unit length.
    """
    summed = vertex_normals[faces].sum(dim=1)
    if not normalize:
        return summed
    return tfunc.normalize(summed, eps=1e-6, dim=1)
|
|
|
def pack(
        vertices:torch.Tensor,
        faces:torch.Tensor,
    )->tuple[torch.Tensor,torch.Tensor]:
    """Removes unused elements in vertices and faces.

    A face counts as unused when its first corner index is 0; a vertex counts as
    unused when no remaining face references it. Row 0 of both tensors is
    always kept. Face indices are renumbered to match the compacted vertices.

    Args:
        vertices: tensor with one vertex per row, row 0 being the dummy vertex.
        faces: (F, 3) long tensor of vertex indices, row 0 being the dummy face.

    Returns:
        Tuple (vertices, faces) with the unused rows dropped. Neither input is
        modified.
    """
    V = vertices.shape[0]
    device = vertices.device

    # faces: drop rows marked unused, keep the dummy face
    used_faces = faces[:,0]!=0
    used_faces[0] = True
    faces = faces[used_faces]

    # vertices: a flat mask is enough here, repeated indices just set True again
    used_vertices = torch.zeros(V,dtype=torch.bool,device=device)
    used_vertices[faces.reshape(-1)] = True
    used_vertices[0] = True
    vertices = vertices[used_vertices]

    # renumber: old vertex index -> position among the kept vertices (0 for dropped ones)
    new_index = torch.zeros(V,dtype=torch.long,device=device)
    new_index[used_vertices] = torch.arange(0,int(used_vertices.sum()),device=device)
    faces = new_index[faces]

    return vertices,faces
|
|
|
def split_edges(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        edges:torch.Tensor,
        face_to_edge:torch.Tensor,
        splits,
        pack_faces:bool=True,
    )->tuple[torch.Tensor,torch.Tensor]:
    """Split the selected edges at their midpoints.

    Every edge flagged in `splits` (bool, one entry per edge) gets a new vertex
    at the mean of its two endpoints, appended after the existing vertices.
    In each face, corner c is replaced by the new vertex of side c when that
    side is split, and a face (old corner c, new vertex of side c, preceding
    corner of the shrunk face) is appended for it.

    With `pack_faces` the all-zero rows produced for unsplit sides are removed
    (row 0 is always kept); otherwise the result has 4*F rows.

    When nothing is flagged the inputs are returned unchanged.
    """
    num_vertices = vertices.shape[0]
    num_faces = faces.shape[0]
    num_splits = splits.sum().item()
    if num_splits==0:
        return vertices,faces

    # index of the vertex created on each edge, 0 where the edge is not split
    midpoint_index = torch.zeros_like(splits, dtype=torch.long)
    midpoint_index[splits] = torch.arange(num_vertices,num_vertices+num_splits,dtype=torch.long,device=vertices.device)

    # vertices
    midpoints = vertices[edges[splits]].mean(dim=1)
    vertices = torch.concat((vertices,midpoints),dim=0)

    # faces
    side_vert = midpoint_index[face_to_edge]
    side_split = side_vert!=0
    shrunk_faces = torch.where(side_split,side_vert,faces)
    cut_off = torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1)
    # unsplit sides yield all-zero (unused) faces
    new_faces = side_split[:,:,None] * cut_off
    faces = torch.concat((shrunk_faces,new_faces.reshape(num_faces*3,3)))

    if pack_faces:
        keep = faces[:,0]!=0
        keep[0] = True
        faces = faces[keep]

    return vertices,faces
|
|
|
def collapse_edges(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        edges:torch.Tensor,
        priorities:torch.Tensor,
        stable:bool=False,
    )->tuple[torch.Tensor,torch.Tensor]:
    """Collapse a spaced-out subset of edges, chosen by priority.

    An edge becomes a candidate when its priority is positive and its rank is
    still the largest one seen after three rounds of rank propagation along
    edges. A candidate is collapsed only if it also passes the connectivity
    count below. Collapsing moves the first endpoint to the midpoint of the
    edge, redirects the second endpoint to the first, and zeroes every face
    that ends up with a repeated corner.

    NOTE: `vertices` is modified in place (and also returned); `faces` is
    replaced by a new tensor.

    Args:
        vertices: vertex positions, indexed along dim 0 by the values in `edges` and `faces`.
        faces: F,3 integer vertex indices.
        edges: E,2 integer vertex indices.
        priorities: one value per edge; only edges with priority > 0 can collapse.
        stable: forwarded to Tensor.sort().

    Returns:
        Tuple (vertices, faces).
    """
    V = vertices.shape[0]

    # check spacing
    # rank[i] = position of edge i when sorted by ascending priority
    _,order = priorities.sort(stable=stable)
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device)
    edge_rank = rank
    for i in range(3):
        # push the current edge ranks onto both endpoints, then pull the larger endpoint
        # rank back onto every edge (presumably out= keeps a running maximum - confirm
        # against the torch_scatter documentation)
        torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
        edge_rank,_ = vert_rank[edges].max(dim=-1)
    # keep edges whose own rank survived the propagation and whose priority is positive
    candidates = edges[(edge_rank==rank).logical_and_(priorities>0)]

    # check connectivity
    # NOTE(review): this appears to count, for the second endpoint of every candidate,
    # the vertices adjacent to both endpoints, and to reject candidates with more
    # than two of them - confirm against the intended collapse criterion
    vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device)
    vert_connections[candidates[:,0]] = 1 # mark first endpoints
    edge_connections = vert_connections[edges].sum(dim=-1)
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) # spread marks to adjacent vertices
    vert_connections[candidates] = 0 # clear both endpoints of every candidate
    edge_connections = vert_connections[edges].sum(dim=-1)
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) # collect the remaining marks
    collapses = candidates[vert_connections[candidates[:,1]] <= 2]

    # move the surviving (first) endpoint to the edge midpoint, in place
    vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)

    # redirect the second endpoint to the first, then zero faces with repeated corners
    dest = torch.arange(0,V,dtype=torch.long,device=vertices.device)
    dest[collapses[:,1]] = dest[collapses[:,0]]
    faces = dest[faces]
    c0,c1,c2 = faces.unbind(dim=-1)
    collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
    faces[collapsed] = 0

    return vertices,faces
|
|
|
def calc_face_collapses(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        edges:torch.Tensor,
        face_to_edge:torch.Tensor,
        edge_length:torch.Tensor,
        face_normals:torch.Tensor,
        vertex_normals:torch.Tensor,
        min_edge_length:torch.Tensor=None,
        area_ratio = 0.5,
        shortest_probability = 0.8
    )->torch.Tensor:
    """Select the edges to collapse in order to remove flipped or too-small faces.

    A face is flagged when its normal points against its reference normal
    (negative dot product). When `min_edge_length` (one value per vertex) is
    given, a face is also flagged when the length of its normal is below
    2 * area_ratio * (mean min_edge_length of its corners)**2; face 0 is never
    flagged in that case.

    Every face then nominates one of its three edges: the shortest one with
    probability `shortest_probability`, otherwise one of the two longer ones
    (always the shortest when shortest_probability >= 1).

    Returns:
        Bool tensor with one entry per edge, True where at least one flagged
        face nominated that edge.
    """
    num_edges = edges.shape[0]
    num_faces = faces.shape[0]

    # face flips
    reference = calc_face_ref_normals(faces,vertex_normals,normalize=False)
    flagged = (face_normals*reference).sum(dim=-1)<0

    # small faces
    if min_edge_length is not None:
        target_length = min_edge_length[faces].mean(dim=-1)
        min_area = target_length**2 * area_ratio
        # the length of an unnormalized cross-product normal is twice the face area
        too_small = face_normals.norm(dim=-1) < min_area*2
        flagged.logical_or_(too_small)
        flagged[0] = False

    # faces to edges
    side_length = edge_length[face_to_edge]
    if shortest_probability<1:
        # draw from [0,randlim) and clamp to 2: position 2 of the descending
        # order (the shortest side) is hit with probability (randlim-2)/randlim
        randlim = round(2/(1-shortest_probability))
        position = torch.randint(0,randlim,size=(num_faces,),device=faces.device).clamp_max_(2)
        longest_first = torch.argsort(side_length,dim=-1,descending=True)
        chosen_side = longest_first.gather(dim=-1,index=position[:,None])
    else:
        chosen_side = torch.argmin(side_length,dim=-1)[:,None]

    chosen_edge = face_to_edge.gather(dim=1,index=chosen_side)[:,0]
    votes = torch.zeros(num_edges,dtype=torch.long,device=vertices.device)
    votes.scatter_add_(dim=0,index=chosen_edge,src=flagged.long())

    return votes.bool()
|
|
|
def flip_edges(
        vertices:torch.Tensor,
        faces:torch.Tensor,
        edges:torch.Tensor,
        edge_to_face:torch.Tensor,
        with_border:bool=True,
        with_normal_check:bool=True,
        stable:bool=False,
    ):
    """Flip interior edges, in place, where that evens out the vertex degrees.

    For every edge the two vertices e0,e1 of the edge and the two remaining
    corners cl,cr of its left and right face are collected. Flipping replaces
    the faces (e0,e1,cl) and (e0,cr,e1) by (e0,cr,cl) and (e1,cl,cr), which
    lowers the degree of e0,e1 by one and raises the degree of cl,cr by one.

    Args:
        vertices: vertex positions, one row per vertex (3 components are
            required when `with_normal_check` is set).
        faces: F,3 long vertex indices; UPDATED IN PLACE.
        edges: E,2 long vertex indices.
        edge_to_face: E,[left,right],[face,side] as produced by calc_edges().
        with_border: lower the degree of vertices whose edges are all inside by 2
            before comparing (presumably so that inside and border vertices
            are judged against the same target degree - confirm).
        with_normal_check: skip flips that would turn one of the two new
            faces against the summed normal of the two old faces.
        stable: forwarded to Tensor.sort().

    Returns:
        None. The result is written into `faces`.
    """
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device

    # vertex degree = number of edge endpoints referencing the vertex
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device)
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')

    # corner (side+2)%3 of each adjacent face is the corner not on the edge
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner]
    # an edge is inside when both neighbors are real (non-zero) vertices
    edge_is_inside = neighbors.all(dim=-1)

    if with_border:
        # a vertex is inside when every edge touching it is inside;
        # float masks are used so that scatter_(reduce='multiply') can combine them
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device)
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2)
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long)
        vertex_degree -= 2 * vertex_is_inside

    neighbor_degrees = vertex_degree[neighbors]
    edge_degrees = vertex_degree[edges]

    # For any common target degree t, the change of Sum((degree-t)**2) caused by a flip is
    #   Sum_neighbors((d+1-t)**2-(d-t)**2) + Sum_edge_vertices((d-1-t)**2-(d-t)**2)
    #   = 2 * (2 + Sum_neighbors(d) - Sum_edge_vertices(d))
    # so only the sign and order of the bracket matter.
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1)
    candidates = torch.logical_and(loss_change<0, edge_is_inside)
    loss_change = loss_change[candidates]
    if loss_change.shape[0]==0:
        return

    # columns: e0, e1, cl, cr
    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1)

    # flip only candidates that outrank every other candidate sharing one of
    # their four vertices, so that no two flips touch the same vertex
    _,order = loss_change.sort(descending=True, stable=stable)
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device)
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1)
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1)
    flip = rank==neighborhood_rank

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   |  L /  |      both triangles ordered counter clockwise
        #   |  / R  |      normals pointing out of screen
        #   V/      |
        #   e0---->-cr
        v = vertices[edges_neighbors]
        v = v - v[:,0:1] # make relative to e0
        e1 = v[:,1]
        cl = v[:,2]
        cr = v[:,3]
        # dim=-1 is required: without it torch.cross picks the first dimension
        # of size 3, which is the candidate axis whenever there are exactly 3 candidates
        n = torch.cross(e1,cl,dim=-1) + torch.cross(cr,e1,dim=-1) # sum of old normal vectors
        flip.logical_and_(torch.sum(n*torch.cross(cr,cl,dim=-1),dim=-1)>0) # first new face
        flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1,dim=-1),dim=-1)>0) # second new face

    flip_edges_neighbors = edges_neighbors[flip]
    flip_edge_to_face = edge_to_face[candidates,:,0][flip]
    # left face becomes (e0,cr,cl), right face becomes (e1,cl,cr)
    flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]]
    faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
|
|