Spaces:
Running
on
L40S
Running
on
L40S
File size: 5,592 Bytes
c371ec2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
from __future__ import annotations
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from jaxtyping import Float, Integer
from torch import Tensor
from sf3d.box_uv_unwrap import box_projection_uv_unwrap
from sf3d.models.utils import dot
class Mesh:
    """Triangle mesh with lazily computed, cached per-vertex attributes.

    The mesh is defined by vertex positions (``v_pos``) and triangle vertex
    indices (``t_pos_idx``).  Vertex normals, tangents, UV coordinates and the
    unique edge list are computed on first access through the corresponding
    properties and cached afterwards.

    Note that accessing ``v_tex`` (and therefore ``v_tng``) on a mesh without
    UVs calls :meth:`unwrap_uv`, which replaces ``v_pos`` and ``t_pos_idx``
    in place.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Store the geometry and register every keyword argument as an extra.

        Args:
            v_pos: Vertex positions.
            t_pos_idx: Per-triangle vertex indices into ``v_pos``.
            **kwargs: Arbitrary additional data, stored in ``self.extras``.
        """
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily populated caches; ``None`` means "not computed yet".
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # UVs are two-dimensional (unwrap_uv reshapes them to (-1, 2)).
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v) -> None:
        """Attach (or overwrite) an arbitrary named value in ``self.extras``."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex position tensor requires gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex normals, computed on first access and cached."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex tangents, computed on first access and cached."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UVs; triggers :meth:`unwrap_uv` if none exist yet."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique undirected edges, computed on first access and cached."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Return unit vertex normals accumulated from the adjacent faces.

        Face normals are left unnormalized before accumulation, so faces with
        a larger cross product contribute more to their vertices.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        for corner in (i0, i1, i2):
            v_nrm.scatter_add_(0, corner[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    @staticmethod
    def _safe_denominator(denom: Tensor, eps: float = 1e-6) -> Tensor:
        """Push ``denom`` at least ``eps`` away from zero, preserving its sign.

        Non-negative entries are clamped to ``>= eps`` and negative entries to
        ``<= -eps``.  A plain ``clip(eps)`` would turn every negative value
        into ``+eps``, discarding both sign and magnitude.
        """
        return torch.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))

    def _compute_vertex_tangent(self):
        """Return unit vertex tangents, orthogonalized against the normals."""
        # Resolve UVs before touching the geometry: the lazy ``v_tex`` property
        # may call unwrap_uv(), which replaces ``v_pos`` and ``t_pos_idx``.
        v_tex = self.v_tex
        v_nrm = self.v_nrm

        # t_nrm_idx is always the same as t_pos_idx
        vn_idx = [self.t_pos_idx[:, i] for i in range(3)]
        pos = [self.v_pos[idx] for idx in vn_idx]
        tex = [v_tex[idx] for idx in vn_idx]

        tangents = torch.zeros_like(v_nrm)
        tansum = torch.zeros_like(v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates while
        # keeping the sign of the UV determinant (negative for mirrored UVs).
        tang = tng_nom / self._safe_denominator(denom)

        # Update all 3 vertices
        for corner in vn_idx:
            idx = corner[:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, v_nrm) * v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        """Generate UVs and re-index the mesh so every face owns its vertices.

        Runs with gradients disabled.  ``v_pos`` and ``t_pos_idx`` are replaced
        in place by an un-indexed copy (three distinct vertices per face), and
        the cached UVs, normals and tangents are recomputed for it.

        Args:
            island_padding: Forwarded to ``box_projection_uv_unwrap``.

        Returns:
            This mesh, to allow chaining.
        """
        # NOTE(review): ``indices`` presumably holds per-face indices into
        # ``uv`` -- confirm against box_projection_uv_unwrap.
        uv, indices = box_projection_uv_unwrap(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        # NOTE: the V coordinate is intentionally not flipped (1 - v) here.
        uv_flat = uv[indices].reshape((-1, 2))

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        # Any cached edge list refers to the old vertex numbering; drop it so
        # the ``edges`` property rebuilds it for the new index buffer.
        self._edges = None

        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()
        return self

    def _compute_edges(self):
        """Return the unique undirected edges as sorted vertex-index pairs."""
        # Gather the three edges of every triangle
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Order each pair so (a, b) and (b, a) collapse to the same row
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
|