Spaces:
Running
on
L40S
Running
on
L40S
Upload 5 files
Browse files- sf3d/models/sf3d_models_camera.py +32 -0
- sf3d/models/sf3d_models_isosurface.py +229 -0
- sf3d/models/sf3d_models_mesh.py +172 -0
- sf3d/models/sf3d_models_network.py +195 -0
- sf3d/models/sf3d_models_utils.py +292 -0
sf3d/models/sf3d_models_camera.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from sf3d.models.utils import BaseModule
|
8 |
+
|
9 |
+
|
10 |
+
class LinearCameraEmbedder(BaseModule):
    """Project concatenated condition tensors through one linear layer.

    The tensors named in ``cfg.conditions`` are taken from the keyword
    arguments of :meth:`forward`, flattened after their first two
    dimensions, concatenated along the last axis and mapped from
    ``in_channels`` to ``out_channels`` features.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Width of the concatenated, flattened conditions (checked in forward).
        in_channels: int = 25
        # Width of the embedding produced by the linear layer.
        out_channels: int = 768
        # Names of the keyword arguments that forward() consumes, in order.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Create the linear projection from the configured channel counts."""
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        """Embed the configured conditions.

        Every name in ``cfg.conditions`` must be present in ``kwargs``; the
        concatenated feature width must equal ``cfg.in_channels``. Both are
        enforced with assertions.
        """
        flattened = []
        for key in self.cfg.conditions:
            assert key in kwargs
            tensor = kwargs[key]
            # cond in shape (B, Nv, ...): keep the first two dims, flatten the rest
            flattened.append(tensor.view(*tensor.shape[:2], -1))

        stacked = torch.cat(flattened, dim=-1)
        assert stacked.shape[-1] == self.cfg.in_channels
        return self.linear(stacked)
|
sf3d/models/sf3d_models_isosurface.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from jaxtyping import Float, Integer
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from .mesh import Mesh
|
10 |
+
|
11 |
+
|
12 |
+
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface helpers; subclasses supply the grid vertices."""

    # (low, high) coordinate bounds; subclasses read this when scaling offsets.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Grid vertex positions. Not implemented on the base class."""
        raise NotImplementedError

    @property
    def requires_instance_per_batch(self) -> bool:
        """Always ``False`` for the base class."""
        return False
|
22 |
+
|
23 |
+
|
24 |
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from per-vertex level values on a tetrahedral grid.

    The grid (``vertices`` and ``indices`` arrays) is loaded with ``np.load``
    from ``tets_path``. Lookup tables and grid data are registered as
    non-persistent buffers.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # One row per occupancy case (16 cases); each row lists up to two
        # triangles as indices into the six tetrahedron edges, -1 = unused.
        triangle_rows = [
            [-1, -1, -1, -1, -1, -1],
            [1, 0, 2, -1, -1, -1],
            [4, 0, 3, -1, -1, -1],
            [1, 4, 2, 1, 3, 4],
            [3, 1, 5, -1, -1, -1],
            [2, 3, 0, 2, 5, 3],
            [1, 4, 0, 1, 5, 4],
            [4, 2, 5, -1, -1, -1],
            [4, 5, 2, -1, -1, -1],
            [4, 1, 0, 4, 5, 1],
            [3, 2, 0, 3, 5, 2],
            [1, 3, 5, -1, -1, -1],
            [4, 1, 2, 4, 3, 1],
            [3, 0, 4, -1, -1, -1],
            [2, 0, 1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1],
        ]
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(triangle_rows, dtype=torch.long),
            persistent=False,
        )

        # Number of triangles emitted for each of the 16 occupancy cases.
        triangle_counts = [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0]
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(triangle_counts, dtype=torch.long),
            persistent=False,
        )

        # Flattened (start, end) corner pairs of the six tetrahedron edges.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        grid = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(grid["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(grid["indices"]).long(), persistent=False
        )

        # Filled lazily by the all_edges property.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center, boundary = self.get_center_boundary_index(self._grid_vertices)
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary, persistent=False)

    def get_center_boundary_index(self, verts):
        """Return the index of the vertex closest to the origin and the indices
        of vertices having any coordinate equal to the global min or max."""
        squared_norm = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(squared_norm)

        touches_max = verts == verts.max()
        touches_min = verts == verts.min()
        touches_any = torch.bitwise_or(touches_min, touches_max)

        # Count of extreme coordinates per vertex; non-zero marks a boundary vertex.
        hits_per_vertex = torch.sum(touches_any.float(), dim=-1)
        boundary_idx = torch.nonzero(hits_per_vertex)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with ``tanh`` and scale them by range / resolution."""
        span = self.points_range[1] - self.points_range[0]
        # half tet size is approximately 1 / self.resolution
        # FIXME: hard-coded activation
        return span / self.resolution * torch.tanh(grid_vertex_offsets)

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """The undeformed grid vertex buffer."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique, per-row sorted vertex-index pairs of all tetrahedron edges (cached)."""
        if self._all_edges is not None:
            return self._all_edges

        # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
        corner_pairs = torch.tensor(
            [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
            dtype=torch.long,
            device=self.indices.device,
        )
        pairs = self.indices[:, corner_pairs].reshape(-1, 2)
        pairs_sorted = torch.sort(pairs, dim=1)[0]
        self._all_edges = torch.unique(pairs_sorted, dim=0)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge so the smaller vertex index comes first."""
        with torch.no_grad():
            # 1 where the first endpoint is larger, i.e. the pair must be swapped.
            swap = (edges_ex2[:, 0] > edges_ex2[:, 1]).long().unsqueeze(dim=1)

            smaller = torch.gather(input=edges_ex2, index=swap, dim=1)
            larger = torch.gather(input=edges_ex2, index=1 - swap, dim=1)

        return torch.stack([smaller, larger], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra.

        Vertices with ``sdf_n > 0`` are treated as occupied. Returns the
        interpolated surface vertices and the triangle index tensor.
        """
        # NOTE(review): the extent of this no_grad block is reconstructed from
        # the data flow (index bookkeeping only); the interpolation below is
        # kept outside so gradients can reach pos_nx3 / sdf_n — confirm
        # against the upstream file.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tetrahedra with mixed occupancy contain surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # An edge carries a surface vertex when exactly one endpoint is occupied.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        endpoint_ids = interp_v.reshape(-1)
        edges_to_interp = pos_nx3[endpoint_ids].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[endpoint_ids].reshape(-1, 2, 1)
        # Negate the second endpoint's value before summing.
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        # Each endpoint is weighted by the (normalized) value of the other one.
        weights = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * weights).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Encode the 4 occupancy bits of every valid tetrahedron as a case index.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[one_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[one_tri]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[two_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[two_tri]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Build a :class:`Mesh` from level values and optional vertex offsets."""
        grid_vertices = self.grid_vertices
        if deformation is not None:
            grid_vertices = grid_vertices + self.normalize_grid_deformation(
                deformation
            )

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        return Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )
|
sf3d/models/sf3d_models_mesh.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Any, Dict, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from jaxtyping import Float, Integer
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from sf3d.box_uv_unwrap import box_projection_uv_unwrap
|
11 |
+
from sf3d.models.utils import dot
|
12 |
+
|
13 |
+
|
14 |
+
class Mesh:
    """Triangle mesh with lazily computed normals, tangents, UVs and edges.

    ``v_pos`` holds vertex positions and ``t_pos_idx`` the vertex indices of
    each triangle. Extra keyword arguments are stored in ``extras``.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily filled caches, see the properties below.
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # UVs are two-component (unwrap_uv reshapes them to (-1, 2)).
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v) -> None:
        """Store an arbitrary named value alongside the mesh."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex position tensor requires gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex tangents, computed on first access."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UVs; the first access triggers :meth:`unwrap_uv`,
        which replaces ``v_pos`` and ``t_pos_idx``."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique edges as per-row sorted vertex index pairs, computed on first access."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Accumulate unnormalized face normals onto their vertices and normalize."""
        corners = [self.t_pos_idx[:, k] for k in range(3)]
        v0, v1, v2 = (self.v_pos[c, :] for c in corners)

        # Not normalized on purpose: the cross product length scales with
        # the triangle area, so larger faces contribute more.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        for c in corners:
            v_nrm.scatter_add_(0, c[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from positions and UVs, orthogonalized
        against the vertex normals."""
        # Read the UVs first: a lazy v_tex access may unwrap the mesh and
        # replace v_pos / t_pos_idx, so everything else is read afterwards.
        v_tex = self.v_tex
        v_nrm = self.v_nrm

        # t_nrm_idx is always the same as t_pos_idx
        vn_idx = [self.t_pos_idx[:, i] for i in range(3)]
        pos = [self.v_pos[idx] for idx in vn_idx]
        tex = [v_tex[idx] for idx in vn_idx]

        tangents = torch.zeros_like(v_nrm)
        tansum = torch.zeros_like(v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        # NOTE(review): clip(1e-6) also maps negative determinants to 1e-6;
        # confirm this is intended for triangles with flipped UV winding.
        denom_safe = denom.clip(1e-6)
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, v_nrm) * v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        """Unwrap the mesh with ``box_projection_uv_unwrap`` in place.

        Vertices are duplicated so every triangle owns its three vertices;
        ``v_pos``, ``t_pos_idx``, UVs, normals and tangents are replaced
        accordingly and the cached edges are dropped.

        Returns:
            This mesh, matching the declared return type.
        """
        uv, indices = box_projection_uv_unwrap(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        # Edges cached for the old indexing no longer match the new vertices.
        self._edges = None
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()
        return self

    def _compute_edges(self):
        """Collect the three edges of every triangle, sort each pair and deduplicate."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
|
sf3d/models/sf3d_models_network.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from jaxtyping import Float
|
9 |
+
from torch import Tensor
|
10 |
+
from torch.autograd import Function
|
11 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
12 |
+
|
13 |
+
from sf3d.models.utils import BaseModule, normalize
|
14 |
+
|
15 |
+
|
16 |
+
class PixelShuffleUpsampleNetwork(BaseModule):
    """Convolution stack followed by ``nn.PixelShuffle``, applied to each of
    the three planes of a triplane tensor."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024  # channels of the incoming planes
        out_channels: int = 40  # channels after the pixel shuffle
        scale_factor: int = 4  # upscale factor passed to nn.PixelShuffle

        conv_layers: int = 4  # number of Conv2d layers
        conv_kernel_size: int = 3  # kernel size shared by all Conv2d layers

    cfg: Config

    def configure(self) -> None:
        """Build the ``upsample`` sequential module from the configuration."""
        cfg = self.cfg
        padding = (cfg.conv_kernel_size - 1) // 2
        # PixelShuffle folds scale_factor**2 channels into each output channel.
        pre_shuffle_channels = cfg.out_channels * cfg.scale_factor**2
        final_index = cfg.conv_layers - 1

        modules = []
        for index in range(cfg.conv_layers):
            is_final = index == final_index
            modules.append(
                nn.Conv2d(
                    cfg.in_channels,
                    pre_shuffle_channels if is_final else cfg.in_channels,
                    cfg.conv_kernel_size,
                    padding=padding,
                )
            )
            # No activation after the last convolution.
            if not is_final:
                modules.append(nn.ReLU(inplace=True))

        modules.append(nn.PixelShuffle(cfg.scale_factor))

        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample all three planes by folding the plane axis into the batch axis."""
        merged = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(merged)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
|
62 |
+
|
63 |
+
|
64 |
+
class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential whose backward pass clamps the saved input to at most 15."""

    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Keep the raw input; the backward pass recomputes exp from it.
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (saved_input,) = ctx.saved_tensors
        # Clamp before exponentiating so the gradient factor is bounded by exp(15).
        return g * saved_input.clamp(max=15).exp()


# Functional entry point for the custom autograd function above.
trunc_exp = _TruncExp.apply
|
81 |
+
|
82 |
+
|
83 |
+
def get_activation(name) -> Callable:
    """Return the activation callable registered under ``name``.

    ``None`` and the names ``none`` / ``linear`` / ``identity`` give the
    identity. Names are lower-cased before lookup. Names not handled here
    are looked up as attributes of ``torch.nn.functional``.

    Raises:
        ValueError: if the name is neither handled here nor an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x

    name = name.lower()
    if name in ("none", "linear", "identity"):
        return lambda x: x
    if name == "trunc_exp":
        return trunc_exp

    named_activations = {
        "lin2srgb": lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0),
        "exp": lambda x: torch.exp(x),
        "shifted_exp": lambda x: torch.exp(x - 1.0),
        "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "shifted_softplus": lambda x: F.softplus(x - 1.0),
        "scale_-11_01": lambda x: x * 0.5 + 0.5,
        "negative": lambda x: -x,
        "normalize_channel_last": lambda x: normalize(x),
        "normalize_channel_first": lambda x: normalize(x, dim=1),
    }
    if name in named_activations:
        return named_activations[name]

    # Fall back to whatever torch.nn.functional exposes under this name.
    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
class HeadSpec:
    """Description of one output head of :class:`MaterialMLP`."""

    # Key under which the head's module and its output are stored.
    name: str
    # Output width of the head's final linear layer.
    out_channels: int
    # Number of (linear, activation) pairs placed before the final layer.
    n_hidden_layers: int
    # Name resolved through get_activation(); None means no activation.
    output_activation: Optional[str] = None
    # Constant added to the head output before the output activation.
    out_bias: float = 0.0
|
131 |
+
|
132 |
+
|
133 |
+
class MaterialMLP(BaseModule):
    """A set of independent MLP heads that share the same input features.

    Each entry of ``cfg.heads`` becomes its own ``nn.Sequential`` stored in
    ``self.heads`` under the head's name.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120  # width of the input features
        n_neurons: int = 64  # width of every hidden layer
        activation: str = "silu"  # hidden activation: "relu" or "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        """Build one sequential network per configured head."""
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Track the current feature width so a head without hidden layers
            # maps in_channels directly to out_channels instead of wrongly
            # expecting n_neurons inputs.
            width = self.cfg.in_channels
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(width, self.cfg.n_neurons),
                    self.make_activation(self.cfg.activation),
                ]
                width = self.cfg.n_neurons
            head_layers.append(nn.Linear(width, head.out_channels))
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Create the hidden-layer activation module.

        Raises:
            NotImplementedError: for names other than ``"relu"`` and ``"silu"``.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def keys(self):
        """Names of the configured heads."""
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Evaluate the selected heads on ``x``.

        Args:
            x: input features fed to every selected head.
            include: if given, only heads with these names are evaluated.
            exclude: if given, heads with these names are skipped.

        Returns:
            Dict mapping head name to its output after adding ``out_bias``
            and applying the head's output activation.

        Raises:
            ValueError: if both ``include`` and ``exclude`` are given.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
|
sf3d/models/sf3d_models_utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import importlib
|
3 |
+
import math
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Any, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import PIL
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from jaxtyping import Bool, Float, Int, Num
|
13 |
+
from omegaconf import DictConfig, OmegaConf
|
14 |
+
from torch import Tensor
|
15 |
+
|
16 |
+
|
17 |
+
class BaseModule(nn.Module):
    """``nn.Module`` base that parses a config and then calls :meth:`configure`.

    Subclasses declare a nested ``Config`` dataclass and override
    :meth:`configure` to create their submodules.
    """

    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        # Turn the raw mapping into a structured config of this class's Config type.
        self.cfg = parse_structured(self.Config, cfg)
        # Remaining arguments are forwarded untouched to the subclass hook.
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Subclass hook called at the end of ``__init__``; must be overridden."""
        raise NotImplementedError
|
33 |
+
|
34 |
+
|
35 |
+
def find_class(cls_string):
    """Import and return the attribute named by a dotted path.

    Everything before the last dot is imported as a module; the part after
    it is looked up on that module.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
|
41 |
+
|
42 |
+
|
43 |
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge ``cfg`` into a structured config built from the dataclass ``fields``.

    Keys of ``cfg`` that are not fields of the dataclass are reported with
    ``print`` and dropped. The caller's ``cfg`` is not modified.

    Args:
        fields: dataclass type describing the allowed keys and defaults.
        cfg: overrides; ``None`` (the default) means no overrides, so the
            result holds the dataclass defaults.

    Returns:
        The merged OmegaConf config.
    """
    # ``cfg`` is optional: treat a missing config as empty instead of
    # failing on ``None.copy()``.
    cfg_ = {} if cfg is None else cfg.copy()

    # Check if cfg.keys are in fields
    field_names = {f.name for f in dataclasses.fields(fields)}
    for key in list(cfg_.keys()):
        # This is helpful when swapping out modules from CLI
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)

    return OmegaConf.merge(OmegaConf.structured(fields), cfg_)
|
56 |
+
|
57 |
+
|
58 |
+
# Default epsilon per floating-point dtype; used by normalize() when the
# caller does not pass an explicit ``eps``.
EPS_DTYPE = {
    # half-precision types
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    # single and double precision
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}
|
64 |
+
|
65 |
+
|
66 |
+
def dot(x, y, dim=-1):
    """Sum of the element-wise product along ``dim``; the reduced axis is kept."""
    product = x * y
    return product.sum(dim, keepdim=True)
|
68 |
+
|
69 |
+
|
70 |
+
def reflect(x, n):
    """Return ``x - 2 * (x . n) * n`` with the dot product taken over the last axis."""
    # Projection of x onto n along the last axis, reduced axis kept for broadcasting.
    projection = torch.sum(x * n, -1, keepdim=True)
    return x - 2 * projection * n
|
72 |
+
|
73 |
+
|
74 |
+
def normalize(x, dim=-1, eps=None):
    """L2-normalize ``x`` along ``dim``.

    When ``eps`` is ``None`` the epsilon is taken from ``EPS_DTYPE`` using
    the dtype of ``x``.
    """
    tolerance = EPS_DTYPE[x.dtype] if eps is None else eps
    return F.normalize(x, p=2, dim=dim, eps=tolerance)
|
78 |
+
|
79 |
+
|
80 |
+
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
    """Append a column of ones to each triangle and give all triangles the
    same winding.

    Triangles whose padded 3x3 matrix has a negative determinant get their
    second and third vertices swapped.
    """
    # One pad for determinant
    padded = F.pad(tri, (0, 1), "constant", 1.0)
    # Same triangles with vertices 1 and 2 exchanged.
    swapped = padded[..., [0, 2, 1], :]
    negative = torch.det(padded) < 0
    padded[negative] = swapped[negative]
    return padded
|
89 |
+
|
90 |
+
|
91 |
+
def triangle_intersection_2d(
    t1: Float[Tensor, "*B 3 2"],
    t2: Float[Tensor, "*B 3 2"],
    eps=1e-12,
) -> Bool[Tensor, "*B"]:  # noqa: F821
    """Returns True if triangles collide, False otherwise

    Both inputs are brought to a common winding with ``tri_winding``. Two
    triangles are reported as not colliding when some edge of either one
    has all three vertices of the other on its outer side.

    Args:
        t1: first batch of 2D triangles.
        t2: second batch of 2D triangles.
        eps: determinants must exceed this value to count as strictly
            inside; ``None`` only tests for a finite log-determinant.

    Returns:
        Boolean tensor with one entry per triangle pair (the result is
        allocated with ``t1.shape[0]`` entries).
    """

    def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]:  # noqa: F821
        # logdet is non-finite for zero or negative determinants, i.e. when
        # the third point is on the edge line or on its outer side.
        logdetx = torch.logdet(x.double())
        if eps is None:
            return ~torch.isfinite(logdetx)
        return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))

    def separated_by_edge_of(edge_tri, other_tri):
        # True where some edge of ``edge_tri`` has every vertex of
        # ``other_tri`` on its outer side.
        found = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
        for shift in range(3):
            edge = torch.roll(edge_tri, shift, dims=1)[:, :2, :]
            # Check if all points of the other triangle lie on the external
            # side of this edge. If this is the case the triangles do not collide
            all_outside = (
                chk_edge(torch.cat((edge, other_tri[:, 0:1]), 1))
                & chk_edge(torch.cat((edge, other_tri[:, 1:2]), 1))
                & chk_edge(torch.cat((edge, other_tri[:, 2:3]), 1))
            )
            found = found | all_outside
        return found

    t1s = tri_winding(t1)
    t2s = tri_winding(t2)

    # Assume the triangles do not collide in the beginning; the flag below is
    # True where a separating edge was found (no collision).
    separated = separated_by_edge_of(t1s, t2s) | separated_by_edge_of(t2s, t1s)

    return ~separated  # Do the inversion
|
133 |
+
|
134 |
+
|
135 |
+
# Scale bounds accepted by scale_tensor(): either a (low, high) pair or a
# tensor whose index 0 holds the lower and index 1 the upper bounds.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
136 |
+
|
137 |
+
|
138 |
+
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the ``inp_scale`` range to the ``tgt_scale`` range.

    ``None`` for either scale means ``(0, 1)``. When ``tgt_scale`` is a
    tensor its last dimension must match that of ``dat`` (asserted).
    """
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale

    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]

    # Bring the data to the unit range first, then stretch to the target range.
    unit = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    return unit * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
150 |
+
|
151 |
+
|
152 |
+
def dilate_fill(img, mask, iterations=10):
    """Grow the valid region of ``img`` outward, filling new pixels from neighbours.

    On every iteration the mask is dilated by one pixel (3x3 max-pool) and the
    pixels that just became valid are blended in from colours averaged over
    3x3 patches of the previous image. Pixels that were already valid keep
    their value, because the blend weight is the difference between the new
    and the old mask.

    The reshapes below hard-code a batch of 1 and 3 colour channels, so
    ``img`` is effectively expected as ``(1, 3, H, W)`` and ``mask`` as
    ``(1, 1, H, W)``.
    NOTE(review): the patch colour sum is not multiplied by the mask, so
    pixels outside ``mask`` are presumably expected to be zero — confirm
    against callers.

    Args:
        img: Image tensor to fill.
        mask: Validity mask; converted to float internally.
        iterations: Number of one-pixel dilation steps.

    Returns:
        The filled image (``img`` itself when ``iterations`` is 0).
    """
    cur_mask = mask.float()
    cur_img = img
    out_size = (img.shape[-2], img.shape[-1])

    ones_3x3 = torch.ones(
        (1, 1, 3, 3),
        dtype=cur_mask.dtype,
        device=cur_mask.device,
    )

    for _ in range(iterations):
        # One-pixel dilation of the current mask.
        grown_mask = F.max_pool2d(cur_mask, kernel_size=3, stride=1, padding=1)

        # 3x3 patches of the image, the old mask and the dilated mask.
        img_patches = F.unfold(cur_img, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_patches = F.unfold(cur_mask, (3, 3)).view(1, 1, 3 * 3, -1)
        grown_patches = F.unfold(grown_mask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Per-patch mean colour over the previously valid pixels
        # (the count is clipped to 1 to avoid dividing by zero).
        colour_sum = img_patches.sum(dim=2)
        valid_count = mask_patches.sum(dim=2).clip(1)
        patch_mean = (colour_sum / valid_count).unsqueeze(2)

        # Spread each patch mean over the patch positions inside the new mask.
        spread = (patch_mean * grown_patches).view(1, 3 * 3 * 3, -1)

        # Sum of the dilated mask over each 3x3 neighbourhood; used to
        # normalise the overlapping contributions accumulated by ``fold``.
        coverage = F.conv2d(grown_mask, ones_3x3, padding=1)
        candidate = F.fold(spread, out_size, (3, 3)) / coverage.clamp(1)

        # Only pixels that became valid in this step take the new colour.
        newly_valid = grown_mask - cur_mask

        cur_mask = grown_mask
        cur_img = torch.lerp(cur_img, candidate, newly_valid)

    return cur_img
|
190 |
+
|
191 |
+
|
192 |
+
def float32_to_uint8_np(
    x: Float[np.ndarray, "*B H W C"],
    dither: bool = True,
    dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
    dither_strength: float = 1.0,
) -> Int[np.ndarray, "*B H W C"]:
    """Quantise a float image to ``uint8`` with optional random dithering.

    The input is scaled by 256, optionally perturbed with random noise,
    floored, clipped to ``[0, 255]`` and cast to ``np.uint8``.

    Args:
        x: Float array to quantise (values in ``[0, 1]`` map onto the full
            ``uint8`` range).
        dither: Whether to add random noise before flooring.
        dither_mask: Optional multiplier applied to the noise, e.g. to
            disable dithering in parts of the image.
        dither_strength: Factor applied to the uniform random sample.
            NOTE(review): the factor is applied before the 0.5 offset is
            subtracted, so values other than 1.0 also shift the noise mean —
            confirm this is intended.

    Returns:
        The quantised ``np.uint8`` array.
    """
    scaled = 256.0 * x
    if dither:
        # A single noise channel (shape of ``x[..., :1]``) is drawn and
        # broadcast across all channels of ``x``.
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        scaled = scaled + noise
    # ``np.uint8`` (not ``torch.uint8``): NumPy's ``astype`` cannot interpret
    # a torch dtype, which made the non-dither path raise TypeError.
    return np.clip(np.floor(scaled), 0, 255).astype(np.uint8)
|
206 |
+
|
207 |
+
|
208 |
+
def convert_data(data):
    """Recursively convert tensors inside ``data`` to NumPy arrays.

    ``None`` and NumPy arrays are returned unchanged. Tensors are detached,
    moved to the CPU and converted; ``float16``/``bfloat16`` tensors are first
    promoted to ``float32``. Lists and dicts are rebuilt with every element /
    value converted recursively.

    Args:
        data: ``None``, ``np.ndarray``, ``torch.Tensor``, or a list / dict
            nesting of those.

    Returns:
        The same structure with every tensor replaced by a NumPy array.

    Raises:
        TypeError: If ``data`` (or a nested element) has an unsupported type.
    """
    if data is None or isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        if data.dtype in (torch.float16, torch.bfloat16):
            data = data.float()
        return data.detach().cpu().numpy()
    if isinstance(data, list):
        return [convert_data(item) for item in data]
    if isinstance(data, dict):
        return {key: convert_data(value) for key, value in data.items()}
    # Build a single message string; passing the type as a second positional
    # argument made the exception render as a tuple repr.
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, "
        f"getting {type(data)}"
    )
|
226 |
+
|
227 |
+
|
228 |
+
class ImageProcessor:
    """Convert images to channel-last tensors and resize them to a square size."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        """Convert one image (or one batch) to a tensor and resize it.

        PIL images and ``uint8`` arrays are converted to ``float32`` and
        divided by 255; other arrays and tensors are used as they are. The
        input is treated as channel-last: ``(H, W, C)`` or, when it has four
        dimensions, ``(B, H, W, C)``.

        Args:
            image: Image to process.
            size: Output height and width.

        Returns:
            The bilinearly resized (antialiased) tensor, channel-last, with
            the same number of dimensions as the input.
        """
        # The three accepted input types are mutually exclusive, so the order
        # of these checks does not affect which branch is taken.
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)

        batched = image.ndim == 4
        if not batched:
            image = image[None, ...]

        # ``F.interpolate`` expects channel-first input; permute there and back.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)

        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        """Resize a single image, a batch, or a list of images.

        Args:
            image: A 4-D array/tensor (already a batch), a single image, or a
                list of single images.
            size: Output height and width.

        Returns:
            A tensor whose leading dimension is the batch: 4-D inputs are
            resized as they are, anything else is resized per image and
            stacked along a new first dimension.
        """
        # Check against ``torch.Tensor`` rather than ``torch.FloatTensor``:
        # the latter only matches CPU float32 tensors, so any other 4-D tensor
        # was wrapped in a list and came back with an extra leading dimension.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            return self.convert_and_resize(image, size)

        images = image if isinstance(image, list) else [image]
        resized = [self.convert_and_resize(im, size) for im in images]
        return torch.stack(resized, dim=0)
|
279 |
+
|
280 |
+
|
281 |
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a 3x3 camera intrinsic matrix from a field of view and image size.

    The focal length is ``0.5 * H / tan(0.5 * fov)`` and is used for both
    diagonal entries; the principal point is the image centre
    ``(W / 2, H / 2)``. ``fov`` is passed to ``np.tan``, i.e. it is in radians.
    NOTE(review): the focal length is derived from ``H`` only, so ``fov`` is
    presumably the vertical field of view — confirm.

    Args:
        fov: Field of view in radians.
        H: Image height.
        W: Image width.
        bs: If positive, the matrix is repeated ``bs`` times along a new
            leading dimension; otherwise a single matrix is returned.

    Returns:
        A ``float32`` tensor of shape ``(3, 3)`` or ``(bs, 3, 3)``.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)

    K = np.eye(3, dtype=np.float32)
    K[0, 0] = focal
    K[1, 1] = focal
    K[0, 2], K[1, 2] = W / 2.0, H / 2.0

    if bs > 0:
        K = np.repeat(K[np.newaxis], bs, axis=0)

    return torch.from_numpy(K)
|