aiqtech committed on
Commit
c371ec2
1 Parent(s): 4191268

Upload 5 files

Browse files
sf3d/models/sf3d_models_camera.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from sf3d.models.utils import BaseModule
8
+
9
+
10
class LinearCameraEmbedder(BaseModule):
    """Embed per-view camera conditions with a single linear layer.

    Every tensor named in ``cfg.conditions`` is flattened per view, the results are
    concatenated along the last axis, and the concatenation is projected from
    ``cfg.in_channels`` to ``cfg.out_channels`` features.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Width of the concatenated, flattened conditions (checked in ``forward``).
        in_channels: int = 25
        # Width of the returned embedding.
        out_channels: int = 768
        # Names of the ``forward`` keyword arguments to embed, in concatenation order.
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        """Create the linear projection described by the config."""
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        """Project the configured conditions to an embedding.

        Args:
            **kwargs: Must contain one tensor per name in ``cfg.conditions``; each is
                expected in shape ``(B, Nv, ...)``. Other keyword arguments are ignored.

        Returns:
            Tensor of shape ``(B, Nv, cfg.out_channels)``.

        Raises:
            AssertionError: If a configured condition is missing or the concatenated
                width differs from ``cfg.in_channels``.
        """
        cond_tensors = []
        for cond_name in self.cfg.conditions:
            assert cond_name in kwargs, f"Missing camera condition: {cond_name}"
            cond = kwargs[cond_name]
            # cond in shape (B, Nv, ...). ``reshape`` (rather than ``view``) also accepts
            # non-contiguous inputs such as permuted or sliced camera tensors.
            cond_tensors.append(cond.reshape(*cond.shape[:2], -1))
        cond_tensor = torch.cat(cond_tensors, dim=-1)
        assert cond_tensor.shape[-1] == self.cfg.in_channels, (
            f"Expected {self.cfg.in_channels} condition channels, "
            f"got {cond_tensor.shape[-1]}"
        )
        return self.linear(cond_tensor)
sf3d/models/sf3d_models_isosurface.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
class IsosurfaceHelper(nn.Module):
    """Common interface of the isosurface extraction helpers.

    Subclasses expose the grid points the field is sampled on via ``grid_vertices``.
    """

    # (low, high) pair; subclasses use ``points_range[1] - points_range[0]`` as the
    # extent of the grid -- presumably the coordinate range of the grid points.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def requires_instance_per_batch(self) -> bool:
        """Always ``False`` for the base helper."""
        return False

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Grid points of the helper; must be provided by subclasses."""
        raise NotImplementedError
22
+
23
+
24
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from a scalar field sampled on a tetrahedral grid.

    The grid (vertices and tetrahedron indices) is loaded from ``tets_path``. ``forward``
    treats vertices with a positive field value as occupied, places one mesh vertex on
    every grid edge whose endpoints differ in occupancy, and looks the triangles of each
    tetrahedron up in a 16-entry table indexed by the 4-bit occupancy code.
    """

    def __init__(self, resolution: int, tets_path: str):
        """Load the tetrahedral grid and register the lookup tables as buffers.

        Args:
            resolution: Grid resolution; only used to scale vertex deformations.
            tets_path: Path of an archive readable by ``np.load`` that provides the
                ``"vertices"`` and ``"indices"`` arrays.
        """
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Row ``c`` lists, for occupancy code ``c``, the local edge ids (0..5) forming up
        # to two triangles; -1 marks unused slots.
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each occupancy code.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The six edges of a tetrahedron as flattened pairs of local vertex ids.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # The context manager closes the archive's file handle once both arrays are read.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            indices = torch.from_numpy(tets["indices"]).long()
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", indices, persistent=False)

        # Lazily computed by ``all_edges``.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Return the vertex closest to the origin and the vertices on the grid boundary.

        A vertex counts as boundary when any of its coordinates equals the global
        maximum or minimum coordinate value of ``verts``.

        Returns:
            ``(center_idx, boundary_idx)``: a scalar index tensor and a 1-D index tensor.
        """
        magn = torch.sum(verts**2, dim=-1)

        center_idx = torch.argmin(magn)
        on_max = verts == verts.max()
        on_min = verts == verts.min()

        boundary = torch.bitwise_or(on_min, on_max)
        boundary = torch.sum(boundary.float(), dim=-1)

        boundary_idx = torch.nonzero(boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with ``tanh`` and scale them by range / resolution."""
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertices loaded from ``tets_path``."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique edges of the grid as sorted vertex-index pairs (computed lazily).

        The cache is a plain attribute, so it does not follow ``Module.to``; it is
        rebuilt when the ``indices`` buffer has moved to another device.
        """
        cached = self._all_edges
        if cached is None or cached.device != self.indices.device:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edges = torch.sort(edges, dim=1)[0]
            self._all_edges = torch.unique(edges, dim=0)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each edge so that the smaller vertex index comes first."""
        with torch.no_grad():
            # order == 1 where the pair is descending, i.e. where it must be swapped.
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra.

        Args:
            pos_nx3: Grid vertex positions.
            sdf_n: Field value per grid vertex; values ``> 0`` count as occupied.
            tet_fx4: Vertex indices of every tetrahedron.

        Returns:
            ``(verts, faces)``: interpolated surface vertices and triangle indices
            into ``verts``.
        """
        # NOTE(review): the indentation of this file was lost in SOURCE. The no_grad
        # scope is assumed to cover only the index bookkeeping, so that gradients reach
        # the interpolated vertices -- confirm against the upstream file.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tetrahedra with mixed occupancy are crossed by the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # An edge carries a surface vertex iff exactly one endpoint is occupied.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # With endpoint values (s0, s1): negate s1, divide the flipped pair by s0 - s1.
        # The weights (-s1, s0) / (s0 - s1) place the vertex at the zero crossing.
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy code per tetrahedron (bit i set when local vertex i is occupied).
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the surface of ``level`` on the (optionally deformed) grid.

        Args:
            level: Field value per grid vertex.
            deformation: Optional raw per-vertex offsets; they are passed through
                ``normalize_grid_deformation`` and added to the grid vertices.

        Returns:
            A ``Mesh`` whose extras hold the grid vertices, the grid edges, ``level``
            and ``deformation``.
        """
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
sf3d/models/sf3d_models_mesh.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from jaxtyping import Float, Integer
8
+ from torch import Tensor
9
+
10
+ from sf3d.box_uv_unwrap import box_projection_uv_unwrap
11
+ from sf3d.models.utils import dot
12
+
13
+
14
class Mesh:
    """Triangle mesh with lazily computed normals, tangents, UVs and edges.

    NOTE(review): the indentation of this file was lost in SOURCE; the block structure
    below is reconstructed from the syntax.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Store positions and triangle indices; extra keyword arguments go to ``extras``."""
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Caches behind the properties of the same name (without the underscore).
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # NOTE(review): annotated "Nt 3", but unwrap_uv stores a (-1, 2) tensor here.
        self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v) -> None:
        """Attach an arbitrary named value to the mesh."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex positions require gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex normals, computed on first access."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex tangents, computed on first access (this reads ``v_tex``)."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UVs; the first access runs ``unwrap_uv``.

        ``unwrap_uv`` replaces ``v_pos`` and ``t_pos_idx``, so reading this property
        can change the vertex and face tensors of the mesh.
        """
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique edges of the triangles, computed on first access."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Sum the un-normalized face normals onto their vertices and normalize."""
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        # Not normalized: the cross product's length scales with the triangle area.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from positions and UVs.

        The per-triangle tangents are averaged per vertex, normalized, and then made
        perpendicular to the vertex normal.
        """
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            # Reading self.v_tex may trigger unwrap_uv, which rewrites v_pos/t_pos_idx.
            tex[i] = self.v_tex[self.t_pos_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        # NOTE(review): clip(1e-6) only bounds from below, so negative determinants
        # (mirrored UV triangles) are also raised to 1e-6 -- confirm this is intended.
        denom_safe = denom.clip(1e-6)
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        """Unwrap UVs with ``box_projection_uv_unwrap`` and re-index the mesh in place.

        Afterwards every triangle owns three vertices of its own, and the normals and
        tangents are recomputed for the new vertex set.

        NOTE(review): annotated as returning ``Mesh``, but there is no return
        statement, so callers receive ``None``.
        """
        uv, indices = box_projection_uv_unwrap(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        """Return the unique undirected edges of all triangles."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # sort() works along the last dimension, i.e. within each vertex pair.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
sf3d/models/sf3d_models_network.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.autograd import Function
11
+ from torch.cuda.amp import custom_bwd, custom_fwd
12
+
13
+ from sf3d.models.utils import BaseModule, normalize
14
+
15
+
16
class PixelShuffleUpsampleNetwork(BaseModule):
    """Upsample triplane features with a small conv stack followed by a pixel shuffle."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 40
        scale_factor: int = 4

        conv_layers: int = 4
        conv_kernel_size: int = 3

    cfg: Config

    def configure(self) -> None:
        """Build ``self.upsample``.

        Every convolution keeps ``in_channels`` features except the last one, which
        emits ``out_channels * scale_factor**2`` channels for the pixel shuffle. A ReLU
        follows every convolution but the last.
        """
        cfg = self.cfg
        kernel = cfg.conv_kernel_size
        shuffle_channels = cfg.out_channels * cfg.scale_factor**2
        last_idx = cfg.conv_layers - 1

        modules = []
        for idx in range(cfg.conv_layers):
            is_last = idx == last_idx
            modules.append(
                nn.Conv2d(
                    cfg.in_channels,
                    shuffle_channels if is_last else cfg.in_channels,
                    kernel,
                    padding=(kernel - 1) // 2,
                )
            )
            if not is_last:
                modules.append(nn.ReLU(inplace=True))
        modules.append(nn.PixelShuffle(cfg.scale_factor))

        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample the three planes by folding them into the batch axis and back."""
        stacked = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(stacked)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
62
+
63
+
64
class _TruncExp(Function):  # pylint: disable=abstract-method
    """``exp`` whose backward pass clamps the input at 15 before exponentiating.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # The forward value is the plain, unclamped exponential.
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        # Only the gradient uses the clamped input.
        (x,) = ctx.saved_tensors
        return g * x.clamp(max=15).exp()


trunc_exp = _TruncExp.apply
81
+
82
+
83
def get_activation(name) -> Callable:
    """Resolve an activation name (case-insensitive) to a callable.

    ``None``, ``"none"``, ``"linear"`` and ``"identity"`` give the identity. Names not
    handled here are looked up on ``torch.nn.functional``.

    Raises:
        ValueError: If the name is neither known here nor an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name in ("none", "linear", "identity"):
        return lambda x: x
    if name == "trunc_exp":
        # Returned as-is, without a wrapping lambda.
        return trunc_exp

    def lin2srgb(x):
        # Piecewise transfer curve; the clamp keeps pow away from values below the knee.
        return torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)

    named = {
        "lin2srgb": lin2srgb,
        "exp": lambda x: torch.exp(x),
        "shifted_exp": lambda x: torch.exp(x - 1.0),
        "shifted_trunc_exp": lambda x: trunc_exp(x - 1.0),
        "sigmoid": lambda x: torch.sigmoid(x),
        "tanh": lambda x: torch.tanh(x),
        "shifted_softplus": lambda x: F.softplus(x - 1.0),
        "scale_-11_01": lambda x: x * 0.5 + 0.5,
        "negative": lambda x: -x,
        "normalize_channel_last": lambda x: normalize(x),
        "normalize_channel_first": lambda x: normalize(x, dim=1),
    }
    if name in named:
        return named[name]

    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
122
+
123
+
124
@dataclass
class HeadSpec:
    """Description of one output head of ``MaterialMLP``."""

    # Key under which the head's output is returned.
    name: str
    # Width of the head's final linear layer.
    out_channels: int
    # Number of (linear, activation) pairs in front of the final linear layer.
    n_hidden_layers: int
    # Activation name resolved through ``get_activation``; ``None`` means identity.
    output_activation: Optional[str] = None
    # Constant added to the head's raw output before the output activation.
    out_bias: float = 0.0
131
+
132
+
133
class MaterialMLP(BaseModule):
    """A set of independent MLP heads that all read the same input features.

    Each head is described by a ``HeadSpec``; ``forward`` returns a dict mapping head
    names to the activated head outputs.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Width of the input features.
        in_channels: int = 120
        # Width of every hidden layer.
        n_neurons: int = 64
        # Hidden activation, "relu" or "silu".
        activation: str = "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        """Build one ``nn.Sequential`` per configured head."""
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Input width of the next linear layer. Tracking it keeps a head with
            # ``n_hidden_layers == 0`` consistent: its only layer must read
            # ``in_channels`` features, not ``n_neurons``.
            width = self.cfg.in_channels
            for _ in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(width, self.cfg.n_neurons),
                    self.make_activation(self.cfg.activation),
                ]
                width = self.cfg.n_neurons
            head_layers.append(nn.Linear(width, head.out_channels))
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a fresh in-place activation module for ``"relu"`` or ``"silu"``.

        Raises:
            NotImplementedError: For any other activation name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(f"Unsupported activation: {activation}")

    def keys(self):
        """Names of all configured heads."""
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Evaluate the selected heads on ``x``.

        Args:
            x: Input features.
            include: If given, only heads whose name is listed are evaluated.
            exclude: If given, heads whose name is listed are skipped.

        Returns:
            Dict mapping each evaluated head's name to
            ``output_activation(head(x) + out_bias)``.

        Raises:
            ValueError: If both ``include`` and ``exclude`` are given.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        return {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }
sf3d/models/sf3d_models_utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import importlib
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Any, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from jaxtyping import Bool, Float, Int, Num
13
+ from omegaconf import DictConfig, OmegaConf
14
+ from torch import Tensor
15
+
16
+
17
class BaseModule(nn.Module):
    """``nn.Module`` configured through a nested ``Config`` dataclass.

    Construction parses ``cfg`` against ``self.Config`` and then calls ``configure``,
    which subclasses override to create their submodules.
    """

    @dataclass
    class Config:
        """Empty base configuration; subclasses declare their own fields."""

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg`` and hand the remaining arguments to ``configure``."""
        super().__init__()
        parsed = parse_structured(self.Config, cfg)
        self.cfg = parsed
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Set up the module; must be overridden by subclasses."""
        raise NotImplementedError
33
+
34
+
35
def find_class(cls_string):
    """Import and return the object named by a dotted path such as ``"pkg.mod.Name"``.

    Everything before the last dot is imported as a module; the part after it is
    fetched from that module with ``getattr``.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    owner = importlib.import_module(module_path, package=None)
    return getattr(owner, attr_name)
41
+
42
+
43
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge ``cfg`` into the structured config defined by the dataclass ``fields``.

    Keys of ``cfg`` that are not fields of the dataclass are reported on stdout and
    dropped; the caller's ``cfg`` is not modified.

    Args:
        fields: Dataclass type describing the allowed keys and their defaults.
        cfg: Overrides for the defaults. ``None`` (the default) means no overrides.

    Returns:
        The merged OmegaConf config.
    """
    # Work on a copy; ``None`` is treated as "no overrides" instead of crashing.
    cfg_ = {} if cfg is None else cfg.copy()

    field_names = {f.name for f in dataclasses.fields(fields)}
    # Snapshot the keys: entries are removed while looping.
    for key in list(cfg_.keys()):
        # This is helpful when swapping out modules from CLI
        if key not in field_names:
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    return OmegaConf.merge(OmegaConf.structured(fields), cfg_)
56
+
57
+
58
# Default epsilon used by ``normalize`` for each floating-point dtype.
EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}


def dot(x, y, dim=-1):
    """Dot product of ``x`` and ``y`` along ``dim``; the reduced axis is kept."""
    return (x * y).sum(dim, keepdim=True)


def reflect(x, n):
    """Return ``x - 2 * dot(x, n) * n`` (``dot`` taken along the last axis)."""
    twice_projection = 2 * dot(x, n)
    return x - twice_projection * n


def normalize(x, dim=-1, eps=None):
    """L2-normalize ``x`` along ``dim``.

    When ``eps`` is ``None`` the value for ``x.dtype`` is taken from ``EPS_DTYPE``.
    """
    eps = EPS_DTYPE[x.dtype] if eps is None else eps
    return F.normalize(x, p=2, dim=dim, eps=eps)
78
+
79
+
80
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
    """Pad 2D triangles with a column of ones and unify their winding.

    Triangles whose padded 3x3 matrix has a negative determinant get their second and
    third vertices swapped, so every returned triangle has a non-negative determinant.
    """
    # One pad for determinant
    padded = F.pad(tri, (0, 1), "constant", 1.0)
    swapped = padded[..., [0, 2, 1], :]
    negative = torch.det(padded) < 0
    padded[negative] = swapped[negative]
    return padded
89
+
90
+
91
def triangle_intersection_2d(
    t1: Float[Tensor, "*B 3 2"],
    t2: Float[Tensor, "*B 3 2"],
    eps=1e-12,
) -> Float[Tensor, "*B"]:  # noqa: F821
    """Returns True if triangles collide, False otherwise

    Two triangles are reported as not colliding when some edge of one of them has all
    three vertices of the other on its external side. The result is a boolean tensor
    with one entry per triangle pair along the first axis.
    """

    def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]:  # noqa: F821
        # True where the determinant of (edge start, edge end, point) is not safely
        # positive, i.e. its log is non-finite or not above log(eps).
        logdetx = torch.logdet(x.double())
        finite = torch.isfinite(logdetx)
        if eps is None:
            return ~finite
        return ~(finite & (logdetx > math.log(eps)))

    t1s = tri_winding(t1)
    t2s = tri_winding(t2)

    # Assume the triangles do not collide in the beginning
    separated = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
    # First the edges of triangle 1 against the points of triangle 2, then the reverse.
    for edge_tri, point_tri in ((t1s, t2s), (t2s, t1s)):
        for shift in range(3):
            edge = torch.roll(edge_tri, shift, dims=1)[:, :2, :]
            # Check if all points of the other triangle lie on the external side of
            # this edge. If this is the case the triangles do not collide.
            all_outside = (
                chk_edge(torch.cat((edge, point_tri[:, 0:1]), 1))
                & chk_edge(torch.cat((edge, point_tri[:, 1:2]), 1))
                & chk_edge(torch.cat((edge, point_tri[:, 2:3]), 1))
            )
            # Here no collision is still True due to inversion
            separated = separated | all_outside

    return ~separated  # Do the inversion
133
+
134
+
135
# A value range accepted by ``scale_tensor``: either a pair of floats or a tensor
# annotated with shape "2 D". ``scale_tensor`` reads index 0 and index 1 of it
# as the two ends of the range.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
136
+
137
+
138
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly map ``dat`` from the range ``inp_scale`` to the range ``tgt_scale``.

    Either range may be ``None``, which stands for ``(0, 1)``. A tensor-valued target
    range must have the same last-axis size as ``dat``.
    """
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]

    src_lo, src_hi = inp_scale[0], inp_scale[1]
    dst_lo, dst_hi = tgt_scale[0], tgt_scale[1]
    # First bring the data to the unit range, then stretch it to the target range.
    unit = (dat - src_lo) / (src_hi - src_lo)
    return unit * (dst_hi - dst_lo) + dst_lo
150
+
151
+
152
def dilate_fill(img, mask, iterations=10):
    """Grow the valid region of ``img`` outwards, one pixel ring per iteration.

    In every iteration the mask is dilated with a 3x3 window and the newly covered
    pixels are filled from the mean colour of the surrounding 3x3 patches; pixels that
    were already covered keep their value.

    The reshapes assume a single image with three channels, i.e. ``img`` of shape
    ``(1, 3, H, W)`` and ``mask`` of shape ``(1, 1, H, W)``.
    """
    cur_mask = mask.float()
    cur_img = img
    out_size = (img.shape[-2], img.shape[-1])

    ones_kernel = torch.ones(
        (1, 1, 3, 3),
        dtype=cur_mask.dtype,
        device=cur_mask.device,
    )

    for _ in range(iterations):
        grown_mask = F.max_pool2d(cur_mask, 3, 1, 1)

        # 3x3 patches of the current image, current mask and grown mask
        img_patches = F.unfold(cur_img, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_patches = F.unfold(cur_mask, (3, 3)).view(1, 1, 3 * 3, -1)
        grown_patches = F.unfold(grown_mask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Average colour of the valid region inside every patch
        valid_count = mask_patches.sum(dim=2).clip(1)
        mean_color = (img_patches.sum(dim=2) / valid_count).unsqueeze(2)
        # Spread it over the grown region
        fill_color = (mean_color * grown_patches).view(1, 3 * 3 * 3, -1)

        # Number of grown-mask pixels in each 3x3 neighbourhood
        coverage = F.conv2d(grown_mask, ones_kernel, padding=1)
        filled = F.fold(fill_color, out_size, (3, 3)) / coverage.clamp(1)

        # 1 exactly on the ring added in this iteration, 0 elsewhere
        ring = grown_mask - cur_mask

        cur_mask = grown_mask
        cur_img = torch.lerp(cur_img, filled, ring)

    return cur_img
190
+
191
+
192
def float32_to_uint8_np(
    x: Float[np.ndarray, "*B H W C"],
    dither: bool = True,
    dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
    dither_strength: float = 1.0,
) -> Int[np.ndarray, "*B H W C"]:
    """Quantize a float array to ``uint8`` via ``floor(256 * x)``, clipped to [0, 255].

    Args:
        x: Array to quantize.
        dither: If true, add random noise before flooring. The noise is drawn once per
            pixel (last axis of size 1) and shared by all channels.
        dither_mask: Optional array multiplied onto the noise.
        dither_strength: Scale applied to the uniform random values before 0.5 is
            subtracted.

    Returns:
        A ``np.uint8`` array with the shape of ``x``.
    """
    if dither:
        # Kept in a separate local so the boolean ``dither`` flag is not overwritten.
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        return np.clip(np.floor((256.0 * x + noise)), 0, 255).astype(np.uint8)
    # NumPy arrays need a NumPy dtype here; a torch dtype raises TypeError.
    return np.clip(np.floor((256.0 * x)), 0, 255).astype(np.uint8)
206
+
207
+
208
def convert_data(data):
    """Recursively convert tensors inside ``data`` to NumPy arrays.

    ``None`` and NumPy arrays are returned unchanged. Tensors are detached and moved to
    the CPU, with float16/bfloat16 tensors cast to float32 first. Lists and dicts are
    converted element by element.

    Raises:
        TypeError: If ``data`` (or a nested element) has any other type.
    """
    if data is None:
        return None
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        if data.dtype in [torch.float16, torch.bfloat16]:
            data = data.float()
        return data.detach().cpu().numpy()
    if isinstance(data, list):
        return [convert_data(d) for d in data]
    if isinstance(data, dict):
        return {k: convert_data(v) for k, v in data.items()}
    # One formatted message instead of two exception args, which rendered as a tuple.
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, "
        f"getting {type(data)}"
    )
226
+
227
+
228
class ImageProcessor:
    """Convert images to float tensors and resize them to a square resolution."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        """Convert one image (or an already batched array) and resize it.

        PIL images and ``uint8`` arrays are scaled by ``1 / 255``; other arrays and
        tensors are used as they are. The input is expected channel-last; 4-D inputs
        are treated as a batch.

        Returns:
            A channel-last tensor resized to ``(size, size)`` with the same number of
            dimensions as the input.
        """
        # The three accepted types are disjoint, so the order of the checks is free.
        if isinstance(image, torch.Tensor):
            pass
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        # F.interpolate works channel-first, hence the permutes around it.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        """Return a batched, resized, channel-last tensor for one or several images.

        A 4-D array or tensor is processed as a batch. Anything else is treated as a
        single image or a list of images, and the results are stacked along a new
        first axis.
        """
        # ``torch.Tensor`` rather than ``torch.FloatTensor``: the latter only matches
        # float32 CPU tensors, so other batched tensors were wrapped in a list and
        # came back with an extra leading axis.
        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image
279
+
280
+
281
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a 3x3 float32 intrinsic matrix from a field of view and an image size.

    The focal length ``0.5 * H / tan(0.5 * fov)`` is used for both axes and the
    principal point is the image centre ``(W / 2, H / 2)``.

    Args:
        fov: Field of view as passed to ``np.tan``.
        H: Image height.
        W: Image width.
        bs: If positive, the matrix is repeated ``bs`` times along a new first axis.

    Returns:
        A tensor of shape ``(3, 3)``, or ``(bs, 3, 3)`` when ``bs > 0``.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)

    K = np.eye(3, dtype=np.float32)
    K[0, 0] = K[1, 1] = focal
    K[0, 2], K[1, 2] = W / 2.0, H / 2.0

    if bs > 0:
        K = np.repeat(K[None], bs, axis=0)

    return torch.from_numpy(K)