aiqtech commited on
Commit
317ed10
1 Parent(s): 45e40db

Upload 5 files

Browse files
sf3d/sf3d_box_uv_unwrap.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from sf3d.models.utils import dot, triangle_intersection_2d
10
+
11
+
12
+ def _box_assign_vertex_to_cube_face(
13
+ vertex_positions: Float[Tensor, "Nv 3"],
14
+ vertex_normals: Float[Tensor, "Nv 3"],
15
+ triangle_idxs: Integer[Tensor, "Nf 3"],
16
+ bbox: Float[Tensor, "2 3"],
17
+ ) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
18
+ # Test to not have a scaled model to fit the space better
19
+ # bbox_min = bbox[:1].mean(-1, keepdim=True)
20
+ # bbox_max = bbox[1:].mean(-1, keepdim=True)
21
+ # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
22
+
23
+ # Create a [0, 1] normalized vertex position
24
+ v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
25
+ # And to [-1, 1]
26
+ v_pos_normalized = 2.0 * v_pos_normalized - 1.0
27
+
28
+ # Get all vertex positions for each triangle
29
+ # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
30
+ v0 = v_pos_normalized[triangle_idxs[:, 0]]
31
+ v1 = v_pos_normalized[triangle_idxs[:, 1]]
32
+ v2 = v_pos_normalized[triangle_idxs[:, 2]]
33
+ tri_stack = torch.stack([v0, v1, v2], dim=1)
34
+
35
+ vn0 = vertex_normals[triangle_idxs[:, 0]]
36
+ vn1 = vertex_normals[triangle_idxs[:, 1]]
37
+ vn2 = vertex_normals[triangle_idxs[:, 2]]
38
+ tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
39
+
40
+ # Just average the normals per face
41
+ face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
42
+
43
+ # Now decide based on the face normal in which box map we project
44
+ # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
45
+ abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
46
+
47
+ axis = torch.tensor(
48
+ [
49
+ [1, 0, 0], # 0
50
+ [-1, 0, 0], # 1
51
+ [0, 1, 0], # 2
52
+ [0, -1, 0], # 3
53
+ [0, 0, 1], # 4
54
+ [0, 0, -1], # 5
55
+ ],
56
+ device=face_normal.device,
57
+ dtype=face_normal.dtype,
58
+ )
59
+ face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
60
+ index = face_normal_axis.argmax(-1)
61
+
62
+ max_axis, uc, vc = (
63
+ torch.ones_like(abs_x),
64
+ torch.zeros_like(tri_stack[..., :1]),
65
+ torch.zeros_like(tri_stack[..., :1]),
66
+ )
67
+ mask_pos_x = index == 0
68
+ max_axis[mask_pos_x] = abs_x[mask_pos_x]
69
+ uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
70
+ vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
71
+
72
+ mask_neg_x = index == 1
73
+ max_axis[mask_neg_x] = abs_x[mask_neg_x]
74
+ uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
75
+ vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
76
+
77
+ mask_pos_y = index == 2
78
+ max_axis[mask_pos_y] = abs_y[mask_pos_y]
79
+ uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
80
+ vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
81
+
82
+ mask_neg_y = index == 3
83
+ max_axis[mask_neg_y] = abs_y[mask_neg_y]
84
+ uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
85
+ vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
86
+
87
+ mask_pos_z = index == 4
88
+ max_axis[mask_pos_z] = abs_z[mask_pos_z]
89
+ uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
90
+ vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
91
+
92
+ mask_neg_z = index == 5
93
+ max_axis[mask_neg_z] = abs_z[mask_neg_z]
94
+ uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
95
+ vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
96
+
97
+ # UC from [-1, 1] to [0, 1]
98
+ max_dim_div = max_axis.max(dim=0, keepdims=True).values
99
+ uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
100
+ vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
101
+
102
+ uv = torch.stack([uc, vc], dim=-1)
103
+
104
+ return uv, index
105
+
106
+
107
def _assign_faces_uv_to_atlas_index(
    vertex_positions: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
    face_index: Integer[Tensor, "Nf 3"],
) -> Integer[Tensor, "Nf"]:  # noqa: F821
    """Assign each triangle to an atlas slice, moving overlapping ones away.

    Three overlap passes are performed:

    1. Slices 0..5 are the directly visible surfaces of each cube side and go
       to the upper two thirds of the UV atlas.
    2. Slices 6..11 hold the first set of occluded surfaces (still stored in
       projected form) in the lower third, left half.
    3. Everything that is still overlapping ends up in slice 12 and is later
       scattered in the lower right of the atlas.

    Args:
        vertex_positions: Vertex coordinates.
        triangle_idxs: Vertex indices of each triangle.
        face_uv: Projected UVs of each triangle corner.
        face_index: Initial cube face id of each triangle; it is cloned, not
            modified.

    Returns:
        The atlas slice id of every triangle, clamped to ``[0, 12]``.
    """
    triangle_pos = vertex_positions[triangle_idxs]
    assign_idx = face_index.clone()
    for overlap_step in range(3):
        overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
        slice_start = overlap_step * 6
        for i in range(slice_start, slice_start + 6):
            mask = assign_idx == i
            if not mask.any():
                continue
            # All triangles currently living in this slice
            uv_triangle = face_uv[mask]
            cur_triangle_pos = triangle_pos[mask]
            num_tris = uv_triangle.shape[0]

            # Bounding circle of each triangle in UV space
            center_uv = uv_triangle.mean(dim=1, keepdim=True)
            uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values

            # Pairwise center distances. A large value on the diagonal keeps a
            # triangle from being paired with itself.
            center_dist = (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
            self_penalty = (
                torch.eye(
                    num_tris,
                    device=uv_triangle.device,
                    dtype=uv_triangle.dtype,
                ).unsqueeze(-1)
                * 1000
            )
            # Only pairs within three radii are candidates; this reduces the
            # number of exact triangle intersection tests.
            search_radius = uv_triangle_radius.view(-1, 1, 1) * 3.0
            potentially_overlapping_mask = (
                (center_dist + self_penalty) <= search_radius
            ).squeeze(-1)
            overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)

            # Keep every unordered pair once (A|B and B|A are the same)
            lower = torch.min(overlap_coords, dim=-1).values
            upper = torch.max(overlap_coords, dim=-1).values
            overlap_coords = torch.unique(torch.stack([lower, upper], dim=1), dim=0)
            first, second = overlap_coords.unbind(-1)

            # Exact intersection test on the reduced candidate set
            its = triangle_intersection_2d(
                uv_triangle[first], uv_triangle[second], eps=1e-6
            )

            # Decide which triangle of each pair is recorded in ``first`` by
            # comparing the triangle centers along the projection axis.
            # NOTE(review): for slices >= 6 this always yields ax == 2, because
            # ``i`` is not reduced modulo 6 here - confirm this is intended.
            ax = 0 if i < 2 else 1 if i < 4 else 2
            use_max = i % 2 == 1

            tri1_c = cur_triangle_pos[first].mean(dim=1)
            tri2_c = cur_triangle_pos[second].mean(dim=1)

            if use_max:
                mark_first = tri1_c[..., ax] > tri2_c[..., ax]
            else:
                mark_first = tri1_c[..., ax] < tri2_c[..., ax]
            first[mark_first] = second[mark_first]

            # The same triangle can appear in several pairs. It counts as
            # overlapping as soon as one of its pairs intersects.
            unique_idx, rev_idx = torch.unique(first, return_inverse=True)

            add = torch.zeros_like(unique_idx, dtype=torch.float32)
            add.index_add_(0, rev_idx, its.float())
            its_mask = add > 0

            # Translate slice-local indices back to global triangle indices
            idx = torch.where(mask)[0][unique_idx]
            overlapping_indicator[idx] = its_mask

        # Move the marked triangles to the next overlap region (shift by 6)
        assign_idx[overlapping_indicator] += 6

    # The exact face placement no longer matters after the first 2 slices
    max_idx = 6 * 2
    return assign_idx.clamp(0, max_idx)
203
+
204
+
205
+ def _find_slice_offset_and_scale(
206
+ index: Integer[Tensor, "Nf"], # noqa: F821
207
+ ) -> Tuple[
208
+ Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
209
+ ]: # noqa: F821
210
+ # 6 due to the 6 cube faces
211
+ off = 1 / 3
212
+ dupl_off = 1 / 6
213
+
214
+ # Here, we need to decide how to pack the textures in the case of overlap
215
+ def x_offset_calc(x, i):
216
+ offset_calc = i // 6
217
+ # Initial coordinates - just 3x2 grid
218
+ if offset_calc == 0:
219
+ return off * x
220
+ else:
221
+ # Smaller 3x2 grid plus eventual shift to right for
222
+ # second overlap
223
+ return dupl_off * x + min(offset_calc - 1, 1) * 0.5
224
+
225
+ def y_offset_calc(x, i):
226
+ offset_calc = i // 6
227
+ # Initial coordinates - just a 3x2 grid
228
+ if offset_calc == 0:
229
+ return off * x
230
+ else:
231
+ # Smaller coordinates in the lowest row
232
+ return dupl_off * x + off * 2
233
+
234
+ offset_x = torch.zeros_like(index, dtype=torch.float32)
235
+ offset_y = torch.zeros_like(index, dtype=torch.float32)
236
+ offset_x_vals = [0, 1, 2, 0, 1, 2]
237
+ offset_y_vals = [0, 0, 0, 1, 1, 1]
238
+ for i in range(index.max().item() + 1):
239
+ mask = index == i
240
+ if not mask.any():
241
+ continue
242
+ offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
243
+ offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
244
+
245
+ div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
246
+ # All overlap elements are saved in half scale
247
+ div_x[index >= 6] = 6
248
+ div_y = div_x.clone() # Same for y
249
+ # Except for the random overlaps
250
+ div_x[index >= 12] = 2
251
+ # But the random overlaps are saved in a large block in the lower thirds
252
+ div_y[index >= 12] = 3
253
+
254
+ return offset_x, offset_y, div_x, div_y
255
+
256
+
257
def rotation_flip_matrix_2d(
    rad: float, flip_x: bool = False, flip_y: bool = False
) -> Float[Tensor, "2 2"]:
    """Build a 2x2 float32 matrix that rotates by ``rad`` and then flips.

    Args:
        rad: Rotation angle in radians.
        flip_x: Negate the first output coordinate after rotating.
        flip_y: Negate the second output coordinate after rotating.

    Returns:
        The product ``flip @ rotation``.
    """
    c, s = math.cos(rad), math.sin(rad)
    rotation = torch.tensor([[c, -s], [s, c]], dtype=torch.float32)

    sign_x = -1 if flip_x else 1
    sign_y = -1 if flip_y else 1
    flip = torch.tensor([[sign_x, 0], [0, sign_y]], dtype=torch.float32)

    return flip @ rotation
272
+
273
+
274
def calculate_tangents(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
) -> Float[Tensor, "Nv 3"]:  # noqa: F821
    """Compute per-vertex tangents from positions and per-corner UVs.

    A tangent is computed for every triangle from its position and UV edge
    differences, accumulated onto the triangle's three vertices, averaged,
    normalized and finally made perpendicular to the vertex normal.

    Args:
        vertex_positions: Vertex coordinates.
        vertex_normals: Per-vertex normals; also defines the output shape.
        triangle_idxs: Vertex indices of each triangle.
        face_uv: UV coordinates of every triangle corner.

    Returns:
        One tangent per vertex, shaped like ``vertex_normals``.
    """
    pos = [vertex_positions[triangle_idxs[:, i]] for i in range(3)]
    tex = face_uv.unbind(1)

    # Compute tangent space for each triangle
    duv1 = tex[1] - tex[0]
    duv2 = tex[2] - tex[0]
    dpos1 = pos[1] - pos[0]
    dpos2 = pos[2] - pos[0]

    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

    # Avoid division by zero for degenerated texture coordinates while keeping
    # the sign of the determinant. A plain lower clip would turn every negative
    # determinant (mirrored UVs) into +1e-6, giving a tangent with the wrong
    # sign and a hugely inflated magnitude.
    denom_safe = torch.where(
        denom >= 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
    )
    tang = tng_nom / denom_safe

    tangents = torch.zeros_like(vertex_normals)
    tansum = torch.zeros_like(vertex_normals)
    ones = torch.ones_like(tang)

    # Accumulate the triangle tangent on all 3 of its vertices
    for i in range(3):
        idx = triangle_idxs[:, i][:, None].repeat(1, 3)
        tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
        tansum.scatter_add_(0, idx, ones)  # tansum[n_i] = tansum[n_i] + 1

    # Average the accumulated tangents. The individual triangle tangents are
    # not normalized first, so they contribute with their own magnitude.
    # Vertices not referenced by any triangle have a count of 0; clamping keeps
    # them at zero instead of producing 0/0 = NaN.
    tangents = tangents / tansum.clamp(min=1.0)

    # Normalize and make sure tangent is perpendicular to normal
    tangents = F.normalize(tangents, dim=1)
    tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)

    return tangents
321
+
322
+
323
def _rotate_uv_slices_consistent_space(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
):
    """Rotate the UVs of every cube face so its tangents match a reference.

    The reference tangent of a vertex is ``(-y, x, 0)`` projected into the
    plane perpendicular to the vertex normal. For each of the six cube faces
    the mean actual tangent is compared against the mean reference tangent,
    and the face's UVs are rotated by the angle between them and rescaled to
    the [0, 1] range.

    Args:
        vertex_positions: Vertex coordinates.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        uv: UVs per triangle corner. Modified in place.
        index: Cube face id per triangle.

    Returns:
        ``uv`` (the same tensor object) with the rotated coordinates.
    """
    tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    # The second positional argument of F.normalize is the norm order ``p``,
    # not the dimension. Pass ``dim`` by keyword so this is a regular L2
    # normalization instead of a p = -1 "norm".
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
        ),
        dim=-1,
    )

    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    def rotation_matrix_2d(theta):
        # Stacking keeps the matrix on the device and dtype of ``theta``
        c, s = torch.cos(theta), torch.sin(theta)
        return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

    # Now find the rotation
    index_mod = index % 6  # Shouldn't happen. Just for safety
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue

        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

        # Signed angle from the actual to the expected mean tangent. The cross
        # term only uses the first two components of the tangents.
        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)

        rot_matrix = rotation_matrix_2d(angle).to(mask.device)
        # Center the uv coordinate to be in the range of -1 to 1 and 0 centered
        uv_cur = uv[mask] * 2 - 1
        # Rotate it
        rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)

        # Rescale the rotated slice to be within the 0-1 range
        lowest, highest = rotated.min(), rotated.max()
        uv[mask] = (rotated - lowest) / (highest - lowest)

    return uv
380
+
381
+
382
+ def _handle_slice_uvs(
383
+ uv: Float[Tensor, "Nf 3 2"],
384
+ index: Integer[Tensor, "Nf"], # noqa: F821
385
+ island_padding: float,
386
+ max_index: int = 6 * 2,
387
+ ) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
388
+ uc, vc = uv.unbind(-1)
389
+
390
+ # Get the second slice (The first overlap)
391
+ index_filter = [index == i for i in range(6, max_index)]
392
+
393
+ # Normalize them to always fully fill the atlas patch
394
+ for i, fi in enumerate(index_filter):
395
+ if fi.sum() > 0:
396
+ # Scale the slice but only up to a factor of 2
397
+ # This keeps the texture resolution with the first slice in line (Half space in UV)
398
+ uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
399
+ vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
400
+
401
+ uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
402
+ vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
403
+
404
+ return torch.stack([uc_padded, vc_padded], dim=-1)
405
+
406
+
407
+ def _handle_remaining_uvs(
408
+ uv: Float[Tensor, "Nf 3 2"],
409
+ index: Integer[Tensor, "Nf"], # noqa: F821
410
+ island_padding: float,
411
+ ) -> Float[Tensor, "Nf 3 2"]:
412
+ uc, vc = uv.unbind(-1)
413
+ # Get all remaining elements
414
+ remaining_filter = index >= 6 * 2
415
+ squares_left = remaining_filter.sum()
416
+
417
+ if squares_left == 0:
418
+ return uv
419
+
420
+ uc = uc[remaining_filter]
421
+ vc = vc[remaining_filter]
422
+
423
+ # Or remaining triangles are distributed in a rectangle
424
+ # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
425
+ ratio = 0.5 * (1 / 3) # 1.5
426
+ # sqrt(744/(0.5*(1/3)))
427
+
428
+ mult = math.sqrt(squares_left / ratio)
429
+ num_square_width = int(math.ceil(0.5 * mult))
430
+ num_square_height = int(math.ceil(squares_left / num_square_width))
431
+
432
+ width = 1 / num_square_width
433
+ height = 1 / num_square_height
434
+
435
+ # The idea is again to keep the texture resolution consistent with the first slice
436
+ # This only occupys half the region in the texture chart but the scaling on the squares
437
+ # assumes full coverage.
438
+ clip_val = min(width, height) * 1.5
439
+ # Now normalize the UVs with taking into account the maximum scaling
440
+ uc = (uc - uc.min(dim=1, keepdim=True).values) / (
441
+ uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
442
+ ).clip(clip_val)
443
+ vc = (vc - vc.min(dim=1, keepdim=True).values) / (
444
+ vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
445
+ ).clip(clip_val)
446
+ # Add a small padding
447
+ uc = (
448
+ uc * (1 - island_padding * num_square_width * 0.5)
449
+ + island_padding * num_square_width * 0.25
450
+ ).clip(0, 1)
451
+ vc = (
452
+ vc * (1 - island_padding * num_square_height * 0.5)
453
+ + island_padding * num_square_height * 0.25
454
+ ).clip(0, 1)
455
+
456
+ uc = uc * width
457
+ vc = vc * height
458
+
459
+ # And calculate offsets for each element
460
+ idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
461
+ x_idx = idx % num_square_width
462
+ y_idx = idx // num_square_width
463
+ # And move each triangle to its own spot
464
+ uc = uc + x_idx[:, None] * width
465
+ vc = vc + y_idx[:, None] * height
466
+
467
+ uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
468
+ vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
469
+
470
+ uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
471
+
472
+ return uv
473
+
474
+
475
def _distribute_individual_uvs_in_atlas(
    face_uv: Float[Tensor, "Nf 3 2"],
    assigned_faces: Integer[Tensor, "Nf"],  # noqa: F821
    offset_x: Float[Tensor, "Nf"],  # noqa: F821
    offset_y: Float[Tensor, "Nf"],  # noqa: F821
    div_x: Float[Tensor, "Nf"],  # noqa: F821
    div_y: Float[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
):
    """Place the per-slice UVs at their final location in the atlas.

    Args:
        face_uv: UVs per triangle corner.
        assigned_faces: Atlas slice id per triangle.
        offset_x: Per-triangle shift added to u after scaling.
        offset_y: Per-triangle shift added to v after scaling.
        div_x: Per-triangle divisor applied to u.
        div_y: Per-triangle divisor applied to v.
        island_padding: Border kept free around each island.

    Returns:
        The placed UVs flattened to one row per triangle corner, i.e. a
        tensor of shape ``(Nf * 3, 2)``.
    """
    # Slices first, then the remaining overlap elements
    slice_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
    placed_uv = _handle_remaining_uvs(slice_uv, assigned_faces, island_padding)

    u_coords, v_coords = placed_uv.unbind(-1)
    # Scale into the slice cell and shift to the slice origin
    u_coords = u_coords / div_x[:, None] + offset_x[:, None]
    v_coords = v_coords / div_y[:, None] + offset_y[:, None]

    return torch.stack([u_coords, v_coords], dim=-1).view(-1, 2)
496
+
497
+
498
+ def _get_unique_face_uv(
499
+ uv: Float[Tensor, "Nf 3 2"],
500
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
501
+ unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
502
+ # And add the face to uv index mapping
503
+ vtex_idx = unique_idx.view(-1, 3)
504
+
505
+ return unique_uv, vtex_idx
506
+
507
+
508
def _align_mesh_with_main_axis(
    vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
    """Rotate the mesh so its principal axes line up with the canonical axes.

    The two main axes come from a low-rank PCA of the vertex positions and
    the third one from their cross product. Each axis is then assigned to the
    canonical axis on which it has its largest absolute component.

    Args:
        vertex_positions: Vertex coordinates.
        vertex_normals: Per-vertex normals.

    Returns:
        The rotated vertex positions and vertex normals.

    Raises:
        ValueError: If the three axes cannot be assigned to three distinct
            canonical axes.
    """
    # Use pca to find the 2 main axis (third is derived by cross product).
    # The seed is fixed so the randomized PCA is repeatable; forking the RNG
    # keeps that from resetting the caller's global random state.
    rng_devices = [vertex_positions.device] if vertex_positions.is_cuda else []
    with torch.random.fork_rng(devices=rng_devices):
        torch.manual_seed(0)
        _, _, v = torch.pca_lowrank(vertex_positions, q=2)
    main_axis, seconday_axis = v[:, 0], v[:, 1]

    main_axis = F.normalize(main_axis, eps=1e-6, dim=-1)
    # Orthogonalize the second axis
    seconday_axis = F.normalize(
        seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
    )
    # Create perpendicular third axis
    third_axis = F.normalize(
        torch.linalg.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
    )

    # Check to which canonical axis each aligns
    main_axis_max_idx = main_axis.abs().argmax().item()
    seconday_axis_max_idx = seconday_axis.abs().argmax().item()
    third_axis_max_idx = third_axis.abs().argmax().item()

    # Now sort the axes based on the argmax so they align with the canonical
    # axes. If two axes have the same argmax move one of them.
    all_possible_axis = {0, 1, 2}
    cur_index = 1
    while len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
        # Find a canonical axis that is not used yet
        missing_axis = (
            all_possible_axis
            - {main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}
        ).pop()
        # Reassign the third axis first as it had the smallest contribution
        # to the overall shape, then the secondary axis
        if cur_index == 1:
            third_axis_max_idx = missing_axis
        elif cur_index == 2:
            seconday_axis_max_idx = missing_axis
        else:
            raise ValueError("Could not find 3 unique axis")
        cur_index += 1

    axes = [None] * 3
    axes[main_axis_max_idx] = main_axis
    axes[seconday_axis_max_idx] = seconday_axis
    axes[third_axis_max_idx] = third_axis
    # Create rotation matrix from the individual axes (one axis per row)
    rot_mat = torch.stack(axes, dim=1).T

    # Now rotate the vertex positions and vertex normals so the mesh aligns
    # with the main axis
    vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
    vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

    return vertex_positions, vertex_normals
567
+
568
+
569
def box_projection_uv_unwrap(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    island_padding: float,
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]:  # noqa: F821
    """UV-unwrap a triangle mesh by projecting it onto its bounding cube.

    Steps: align the mesh with its main axes, project each triangle onto a
    cube face, rotate the UV islands into a consistent tangent space, resolve
    overlaps by assigning triangles to atlas slices, place the slices in the
    atlas and finally deduplicate the resulting UV coordinates.

    Args:
        vertex_positions: Vertex coordinates.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        island_padding: Border kept free around each UV island.

    Returns:
        The result of ``_get_unique_face_uv``: the unique UV coordinates and
        the per-triangle indices into them.
    """
    # Align the mesh with main axis directions first
    aligned_positions, aligned_normals = _align_mesh_with_main_axis(
        vertex_positions, vertex_normals
    )

    # Axis-aligned bounding box of the aligned mesh: row 0 = min, row 1 = max
    lower_corner = aligned_positions.min(dim=0).values
    upper_corner = aligned_positions.max(dim=0).values
    bbox = torch.stack([lower_corner, upper_corner], dim=0)

    # Decide in which cube face each triangle is placed
    face_uv, face_index = _box_assign_vertex_to_cube_face(
        aligned_positions, aligned_normals, triangle_idxs, bbox
    )

    # Rotate the UV islands so they align with the radial z tangent space
    face_uv = _rotate_uv_slices_consistent_space(
        aligned_positions, aligned_normals, triangle_idxs, face_uv, face_index
    )

    # Find where each face is placed in the atlas; this detects overlaps
    assigned_atlas_index = _assign_faces_uv_to_atlas_index(
        aligned_positions, triangle_idxs, face_uv, face_index
    )

    # Final offset and scale in the atlas based on the assignment
    offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
        assigned_atlas_index
    )

    # Distribute the faces in the uv atlas
    placed_uv = _distribute_individual_uvs_in_atlas(
        face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
    )

    # And get the unique per-triangle UV coordinates
    return _get_unique_face_uv(placed_uv)
sf3d/sf3d_system.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import trimesh
9
+ from einops import rearrange
10
+ from huggingface_hub import hf_hub_download
11
+ from jaxtyping import Float
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ from safetensors.torch import load_model
15
+ from torch import Tensor
16
+
17
+ from sf3d.models.isosurface import MarchingTetrahedraHelper
18
+ from sf3d.models.mesh import Mesh
19
+ from sf3d.models.utils import (
20
+ BaseModule,
21
+ ImageProcessor,
22
+ convert_data,
23
+ dilate_fill,
24
+ dot,
25
+ find_class,
26
+ float32_to_uint8_np,
27
+ normalize,
28
+ scale_tensor,
29
+ )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
31
+
32
+ from .texture_baker import TextureBaker
33
+
34
+
35
+ class SF3D(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ cond_image_size: int
39
+ isosurface_resolution: int
40
+ isosurface_threshold: float = 10.0
41
+ radius: float = 1.0
42
+ background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
43
+ default_fovy_deg: float = 40.0
44
+ default_distance: float = 1.6
45
+
46
+ camera_embedder_cls: str = ""
47
+ camera_embedder: dict = field(default_factory=dict)
48
+
49
+ image_tokenizer_cls: str = ""
50
+ image_tokenizer: dict = field(default_factory=dict)
51
+
52
+ tokenizer_cls: str = ""
53
+ tokenizer: dict = field(default_factory=dict)
54
+
55
+ backbone_cls: str = ""
56
+ backbone: dict = field(default_factory=dict)
57
+
58
+ post_processor_cls: str = ""
59
+ post_processor: dict = field(default_factory=dict)
60
+
61
+ decoder_cls: str = ""
62
+ decoder: dict = field(default_factory=dict)
63
+
64
+ image_estimator_cls: str = ""
65
+ image_estimator: dict = field(default_factory=dict)
66
+
67
+ global_estimator_cls: str = ""
68
+ global_estimator: dict = field(default_factory=dict)
69
+
70
+ cfg: Config
71
+
72
+ @classmethod
73
+ def from_pretrained(
74
+ cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
75
+ ):
76
+ if os.path.isdir(pretrained_model_name_or_path):
77
+ config_path = os.path.join(pretrained_model_name_or_path, config_name)
78
+ weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
79
+ else:
80
+ config_path = hf_hub_download(
81
+ repo_id=pretrained_model_name_or_path, filename=config_name
82
+ )
83
+ weight_path = hf_hub_download(
84
+ repo_id=pretrained_model_name_or_path, filename=weight_name
85
+ )
86
+
87
+ cfg = OmegaConf.load(config_path)
88
+ OmegaConf.resolve(cfg)
89
+ model = cls(cfg)
90
+ load_model(model, weight_path)
91
+ return model
92
+
93
+ @property
94
+ def device(self):
95
+ return next(self.parameters()).device
96
+
97
+ def configure(self):
98
+ self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
99
+ self.cfg.image_tokenizer
100
+ )
101
+ self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
102
+ self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
103
+ self.cfg.camera_embedder
104
+ )
105
+ self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
106
+ self.post_processor = find_class(self.cfg.post_processor_cls)(
107
+ self.cfg.post_processor
108
+ )
109
+ self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
110
+ self.image_estimator = find_class(self.cfg.image_estimator_cls)(
111
+ self.cfg.image_estimator
112
+ )
113
+ self.global_estimator = find_class(self.cfg.global_estimator_cls)(
114
+ self.cfg.global_estimator
115
+ )
116
+
117
+ self.bbox: Float[Tensor, "2 3"]
118
+ self.register_buffer(
119
+ "bbox",
120
+ torch.as_tensor(
121
+ [
122
+ [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
123
+ [self.cfg.radius, self.cfg.radius, self.cfg.radius],
124
+ ],
125
+ dtype=torch.float32,
126
+ ),
127
+ )
128
+ self.isosurface_helper = MarchingTetrahedraHelper(
129
+ self.cfg.isosurface_resolution,
130
+ os.path.join(
131
+ os.path.dirname(__file__),
132
+ "..",
133
+ "load",
134
+ "tets",
135
+ f"{self.cfg.isosurface_resolution}_tets.npz",
136
+ ),
137
+ )
138
+
139
+ self.baker = TextureBaker()
140
+ self.image_processor = ImageProcessor()
141
+
142
+ def triplane_to_meshes(
143
+ self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
144
+ ) -> list[Mesh]:
145
+ meshes = []
146
+ for i in range(triplanes.shape[0]):
147
+ triplane = triplanes[i]
148
+ grid_vertices = scale_tensor(
149
+ self.isosurface_helper.grid_vertices.to(triplanes.device),
150
+ self.isosurface_helper.points_range,
151
+ self.bbox,
152
+ )
153
+
154
+ values = self.query_triplane(grid_vertices, triplane)
155
+ decoded = self.decoder(values, include=["vertex_offset", "density"])
156
+ sdf = decoded["density"] - self.cfg.isosurface_threshold
157
+
158
+ deform = decoded["vertex_offset"].squeeze(0)
159
+
160
+ mesh: Mesh = self.isosurface_helper(
161
+ sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
162
+ )
163
+ mesh.v_pos = scale_tensor(
164
+ mesh.v_pos, self.isosurface_helper.points_range, self.bbox
165
+ )
166
+
167
+ meshes.append(mesh)
168
+
169
+ return meshes
170
+
171
+ def query_triplane(
172
+ self,
173
+ positions: Float[Tensor, "*B N 3"],
174
+ triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
175
+ ) -> Float[Tensor, "*B N F"]:
176
+ batched = positions.ndim == 3
177
+ if not batched:
178
+ # no batch dimension
179
+ triplanes = triplanes[None, ...]
180
+ positions = positions[None, ...]
181
+ assert triplanes.ndim == 5 and positions.ndim == 3
182
+
183
+ positions = scale_tensor(
184
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
185
+ )
186
+
187
+ indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
188
+ (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
189
+ dim=-3,
190
+ ).to(triplanes.dtype)
191
+ out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
192
+ rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
193
+ rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
194
+ align_corners=True,
195
+ mode="bilinear",
196
+ )
197
+ out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
198
+
199
+ return out
200
+
201
def get_scene_codes(self, batch) -> Tuple[Float[Tensor, "B 3 C H W"], Any]:
    """Encode the conditioning views into scene codes.

    Args:
        batch: Mapping with at least ``rgb_cond``, ``mask_cond``,
            ``c2w_cond``, ``intrinsic_cond`` and ``intrinsic_normed_cond``.
            The mapping is modified in place: when ``rgb_cond`` is 4D
            (a single view), a view axis is inserted at dim 1 of all five
            entries.

    Returns:
        A pair ``(scene_codes, direct_codes)``: ``direct_codes`` is the
        output of ``self.tokenizer.detokenize`` and ``scene_codes`` is that
        output passed through ``self.post_processor``.
    """
    # if batch[rgb_cond] is only one view, add a view dimension
    if len(batch["rgb_cond"].shape) == 4:
        for key in (
            "rgb_cond",
            "mask_cond",
            "c2w_cond",
            "intrinsic_cond",
            "intrinsic_normed_cond",
        ):
            batch[key] = batch[key].unsqueeze(1)
    batch_size, n_input_views = batch["rgb_cond"].shape[:2]

    # The whole batch is forwarded as keyword arguments to the embedder.
    camera_embeds: Optional[Float[Tensor, "B Nv Cc"]] = self.camera_embedder(**batch)

    input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
        rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
        modulation_cond=camera_embeds,
    )

    # Merge the view axis into the token axis so the backbone attends over
    # the tokens of all views at once.
    input_image_tokens = rearrange(
        input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
    )

    tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

    tokens = self.backbone(
        tokens,
        encoder_hidden_states=input_image_tokens,
        modulation_cond=None,
    )

    direct_codes = self.tokenizer.detokenize(tokens)
    scene_codes = self.post_processor(direct_codes)
    return scene_codes, direct_codes
234
+
235
def run_image(
    self,
    image: Image,
    bake_resolution: int,
    estimate_illumination: bool = False,
) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
    """Reconstruct a textured mesh from a single RGBA image.

    The image is resized to ``cfg.cond_image_size`` squared, composited over
    ``cfg.background_color`` using its alpha channel, paired with the default
    conditioning camera, and passed to ``generate_mesh``.

    Args:
        image: Input image; its ``mode`` must be ``"RGBA"``.
        bake_resolution: Forwarded to ``generate_mesh``.
        estimate_illumination: Forwarded to ``generate_mesh``.

    Returns:
        The first mesh returned by ``generate_mesh`` together with its
        ``global_dict``.

    Raises:
        ValueError: If ``image`` is not in RGBA mode.
    """
    if image.mode != "RGBA":
        raise ValueError("Image must be in RGBA mode")

    cond_size = self.cfg.cond_image_size
    device = self.device

    # 8-bit pixel values -> float32 in [0, 1] on the model device.
    resized = image.resize((cond_size, cond_size))
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clip(0, 1).to(device)

    # The last channel is alpha; it doubles as the foreground mask and as
    # the blend weight between background color and image color.
    mask_cond = img_cond[:, :, -1:]
    background = torch.tensor(self.cfg.background_color, device=device)[
        None, None, :
    ]
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    c2w_cond = default_cond_c2w(self.cfg.default_distance).to(device)
    intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
        self.cfg.default_fovy_deg,
        cond_size,
        cond_size,
    )

    batch = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.to(device).unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.to(device).unsqueeze(0),
    }

    meshes, global_dict = self.generate_mesh(
        batch, bake_resolution, estimate_illumination
    )
    first_mesh = meshes[0]
    return first_mesh, global_dict
280
+
281
def generate_mesh(
    self,
    batch,
    bake_resolution: int,
    estimate_illumination: bool = False,
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
    """Generate one textured ``trimesh.Trimesh`` per batch element.

    Steps: preprocess the conditioning images, compute scene codes, extract
    meshes, then for each non-empty mesh unwrap UVs, bake the decoded
    material attributes into ``bake_resolution`` x ``bake_resolution``
    textures and wrap everything into a PBR-textured trimesh object.

    Args:
        batch: Conditioning mapping; ``rgb_cond`` and ``mask_cond`` are
            overwritten in place with the output of ``self.image_processor``.
        bake_resolution: Side length in texels of the baked textures.
        estimate_illumination: When true and ``self.global_estimator`` is
            set, its outputs are merged into the returned ``global_dict``.

    Returns:
        ``(rets, global_dict)``: the list of meshes (an empty
        ``trimesh.Trimesh()`` stands in for any mesh without vertices) and
        the dict of estimator outputs.
    """
    # NOTE(review): presumably resizes to the conditioning resolution;
    # image_processor is defined elsewhere - confirm.
    batch["rgb_cond"] = self.image_processor(
        batch["rgb_cond"], self.cfg.cond_image_size
    )
    batch["mask_cond"] = self.image_processor(
        batch["mask_cond"], self.cfg.cond_image_size
    )
    scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

    # Collect the outputs of the optional estimators. The image estimator
    # sees the masked conditioning image, the global estimator the codes
    # before post-processing.
    global_dict = {}
    if self.image_estimator is not None:
        global_dict.update(
            self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
        )
    if self.global_estimator is not None and estimate_illumination:
        global_dict.update(self.global_estimator(non_postprocessed_codes))

    # Mesh extraction and baking run without gradients and with CUDA
    # autocast explicitly disabled.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", enabled=False):
            meshes = self.triplane_to_meshes(scene_codes)

            rets = []
            for i, mesh in enumerate(meshes):
                # Check for empty mesh
                if mesh.v_pos.shape[0] == 0:
                    rets.append(trimesh.Trimesh())
                    continue

                mesh.unwrap_uv()

                # Build textures
                rast = self.baker.rasterize(
                    mesh.v_tex, mesh.t_pos_idx, bake_resolution
                )
                # Boolean texel mask used for all gathers/scatters below.
                bake_mask = self.baker.get_mask(rast)

                # Surface position for every texel, then only covered ones.
                pos_bake = self.baker.interpolate(
                    mesh.v_pos,
                    rast,
                    mesh.t_pos_idx,
                    mesh.v_tex,
                )
                gb_pos = pos_bake[bake_mask]

                # query_triplane keeps a batch axis of size one for
                # unbatched input, hence the [0].
                tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
                decoded = self.decoder(
                    tri_query, exclude=["density", "vertex_offset"]
                )

                # Interpolated, re-normalized vertex normals per texel.
                nrm = self.baker.interpolate(
                    mesh.v_nrm,
                    rast,
                    mesh.t_pos_idx,
                    mesh.v_tex,
                )
                gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
                decoded["normal"] = gb_nrm

                # Check if any keys in global_dict start with decoder_;
                # those are copied into decoded (prefix removed), taking
                # the entry that belongs to this batch element.
                for k, v in global_dict.items():
                    if k.startswith("decoder_"):
                        decoded[k.replace("decoder_", "")] = v[i]

                mat_out = {
                    "albedo": decoded["features"],
                    "roughness": decoded["roughness"],
                    "metallic": decoded["metallic"],
                    "normal": normalize(decoded["perturb_normal"]),
                    "bump": None,
                }

                # Values of existing keys are replaced while iterating;
                # no key is added or removed, so the iteration stays valid.
                # "bump" comes after "normal" in insertion order, so by the
                # time it is visited it already holds the texture written
                # during the "normal" iteration and is skipped by the
                # shape check below.
                for k, v in mat_out.items():
                    if v is None:
                        continue
                    if v.shape[0] == 1:
                        # Skip and directly add a single value
                        mat_out[k] = v[0]
                    else:
                        f = torch.zeros(
                            bake_resolution,
                            bake_resolution,
                            v.shape[-1],
                            dtype=v.dtype,
                            device=v.device,
                        )
                        # Already a full texture: nothing to scatter.
                        if v.shape == f.shape:
                            continue
                        if k == "normal":
                            # Use un-normalized tangents here so that larger smaller tris
                            # Don't effect the tangents that much
                            tng = self.baker.interpolate(
                                mesh.v_tng,
                                rast,
                                mesh.t_pos_idx,
                                mesh.v_tex,
                            )
                            gb_tng = tng[bake_mask]
                            gb_tng = F.normalize(gb_tng, dim=-1)
                            gb_btng = F.normalize(
                                torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
                            )
                            normal = F.normalize(mat_out["normal"], dim=-1)

                            # Express the predicted normal in the
                            # (tangent, bitangent, normal) frame.
                            bump = torch.cat(
                                # Check if we have to flip some things
                                (
                                    dot(normal, gb_tng),
                                    dot(normal, gb_btng),
                                    dot(normal, gb_nrm).clip(
                                        0.3, 1
                                    ),  # Never go below 0.3. This would indicate a flipped (or close to one) normal
                                ),
                                -1,
                            )
                            # Map [-1, 1] components to [0, 1] for storage.
                            bump = (bump * 0.5 + 0.5).clamp(0, 1)

                            # The result is stored under "bump";
                            # mat_out["normal"] itself is left unchanged.
                            f[bake_mask] = bump.view(-1, 3)
                            mat_out["bump"] = f
                        else:
                            f[bake_mask] = v.view(-1, v.shape[-1])
                            mat_out[k] = f

                def uv_padding(arr):
                    # 1D inputs (single values) are returned untouched;
                    # textures go through dilate_fill in channel-first
                    # layout and are converted back afterwards.
                    # NOTE(review): dilate_fill is defined elsewhere;
                    # presumably it grows covered texels into the empty
                    # border around UV islands - confirm.
                    if arr.ndim == 1:
                        return arr
                    return (
                        dilate_fill(
                            arr.permute(2, 0, 1)[None, ...],
                            bake_mask.unsqueeze(0).unsqueeze(0),
                            iterations=bake_resolution // 150,
                        )
                        .squeeze(0)
                        .permute(1, 2, 0)
                    )

                verts_np = convert_data(mesh.v_pos)
                faces = convert_data(mesh.t_pos_idx)
                uvs = convert_data(mesh.v_tex)

                basecolor_tex = Image.fromarray(
                    float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
                ).convert("RGB")
                basecolor_tex.format = "JPEG"

                # .item() requires these to hold exactly one element each.
                metallic = mat_out["metallic"].squeeze().cpu().item()
                roughness = mat_out["roughness"].squeeze().cpu().item()

                if "bump" in mat_out and mat_out["bump"] is not None:
                    bump_np = convert_data(uv_padding(mat_out["bump"]))
                    # Reference texel (0.5, 0.5, 1) used to detect texels
                    # that are exactly flat.
                    bump_up = np.ones_like(bump_np)
                    bump_up[..., :2] = 0.5
                    bump_up[..., 2:] = 1
                    bump_tex = Image.fromarray(
                        float32_to_uint8_np(
                            bump_np,
                            dither=True,
                            # Do not dither if something is perfectly flat
                            dither_mask=np.all(
                                bump_np == bump_up, axis=-1, keepdims=True
                            ).astype(np.float32),
                        )
                    ).convert("RGB")
                    bump_tex.format = (
                        "JPEG"  # PNG would be better but the assets are larger
                    )
                else:
                    bump_tex = None

                material = trimesh.visual.material.PBRMaterial(
                    baseColorTexture=basecolor_tex,
                    roughnessFactor=roughness,
                    metallicFactor=metallic,
                    normalTexture=bump_tex,
                )

                tmesh = trimesh.Trimesh(
                    vertices=verts_np,
                    faces=faces,
                    visual=trimesh.visual.texture.TextureVisuals(
                        uv=uvs, material=material
                    ),
                )
                # Re-orient the result: -90 degrees about X, then
                # +90 degrees about Y.
                rot = trimesh.transformations.rotation_matrix(
                    np.radians(-90), [1, 0, 0]
                )
                tmesh.apply_transform(rot)
                tmesh.apply_transform(
                    trimesh.transformations.rotation_matrix(
                        np.radians(90), [0, 1, 0]
                    )
                )

                tmesh.invert()

                rets.append(tmesh)

    return rets, global_dict
sf3d/sf3d_texture_baker.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import slangtorch
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Bool, Float
7
+ from torch import Tensor
8
+
9
+
10
class TextureBaker(nn.Module):
    """UV-space rasterizer and attribute interpolator backed by Slang kernels.

    Both kernels are launched with 16x16 thread blocks and
    ``resolution // 16`` blocks per axis, so texels beyond the last full
    block are never touched by a kernel.
    """

    def __init__(self):
        super().__init__()
        # The kernel source is expected next to this Python file.
        self.baker = slangtorch.loadModule(
            os.path.join(os.path.dirname(__file__), "texture_baker.slang")
        )

    def rasterize(
        self,
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
        """Rasterize the triangles into a square UV-space image.

        Args:
            uv: Per-vertex UV coordinates (CUDA tensor).
            face_indices: Per-triangle vertex indices (CUDA tensor).
            bake_resolution: Side length of the output image.

        Returns:
            Image whose first three channels hold barycentric coordinates
            and whose last channel holds the triangle index, or ``-1`` where
            no triangle was hit.

        Raises:
            ValueError: If ``uv`` or ``face_indices`` is not on CUDA.
        """
        if not face_indices.is_cuda or not uv.is_cuda:
            raise ValueError("All input tensors must be on cuda")

        face_indices = face_indices.to(torch.int32)
        uv = uv.to(torch.float32)

        # Start from the kernel's own "no hit" value (0, 0, 0, -1) instead
        # of uninitialized memory: when bake_resolution is not a multiple of
        # the block size, the trailing rows/columns are not covered by the
        # launch grid and would otherwise contain garbage that get_mask
        # could interpret as valid triangle indices.
        rast_result = torch.zeros(
            bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
        )
        rast_result[..., -1] = -1

        block_size = 16
        grid_size = bake_resolution // block_size
        self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
            blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
        )

        return rast_result

    def get_mask(
        self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
    ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
        """Return True for texels whose stored triangle index is >= 0."""
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Float[Tensor, "Nv 3"],
        rast: Float[Tensor, "bake_resolution bake_resolution 4"],
        face_indices: Float[Tensor, "Nf 3"],
        uv: Float[Tensor, "Nv 2"],
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Interpolate a per-vertex attribute over a rasterization result.

        Args:
            attr: Per-vertex 3-component attribute (CUDA tensor).
            rast: Output of ``rasterize`` (CUDA tensor).
            face_indices: Per-triangle vertex indices (CUDA tensor).
            uv: Accepted for interface compatibility; not used by the
                kernel call.

        Returns:
            Image of interpolated attributes; texels that are not written by
            the kernel stay zero.

        Raises:
            ValueError: If ``attr``, ``face_indices`` or ``rast`` is not on
                CUDA.
        """
        # Make sure all input tensors are on cuda
        if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
            raise ValueError("All input tensors must be on cuda")

        attr = attr.to(torch.float32)
        face_indices = face_indices.to(torch.int32)

        pos_bake = torch.zeros(
            rast.shape[0],
            rast.shape[1],
            3,
            device=attr.device,
            dtype=attr.dtype,
        )

        block_size = 16
        grid_size = rast.shape[0] // block_size
        self.baker.interpolate(
            attr=attr, indices=face_indices, rast=rast, output=pos_bake
        ).launchRaw(
            blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
        )

        return pos_bake

    def forward(
        self,
        attr: Float[Tensor, "Nv 3"],
        uv: Float[Tensor, "Nv 2"],
        face_indices: Float[Tensor, "Nf 3"],
        bake_resolution: int,
    ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
        """Rasterize the UV layout and interpolate ``attr`` in one call."""
        rast = self.rasterize(uv, face_indices, bake_resolution)
        return self.interpolate(attr, rast, face_indices, uv)
sf3d/sf3d_texture_baker.slang ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // xy: 2D test position
2
+ // v1: vertex position 1
3
+ // v2: vertex position 2
4
+ // v3: vertex position 3
5
+ //
6
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
{
    // Computes the barycentric coordinates (u, v, w) of xy with respect to
    // the triangle (v1, v2, v3) and returns true when xy lies inside the
    // triangle or on its boundary. The three outputs are always written,
    // whether or not the point is inside.
    float2 edge12 = v2 - v1;
    float2 edge13 = v3 - v1;
    float2 offset = xy - v1;

    // Dot products of the two edge vectors and the offset vector.
    float d00 = dot(edge12, edge12);
    float d01 = dot(edge12, edge13);
    float d11 = dot(edge13, edge13);
    float d20 = dot(offset, edge12);
    float d21 = dot(offset, edge13);

    // Solve the 2x2 system; v weights v2, w weights v3, u weights v1.
    float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0 - v - w;

    bool inside = (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
    return inside;
}
27
+
28
[AutoPyBindCUDA]
[CUDAKernel]
void interpolate(
    TensorView<float3> attr,
    TensorView<int3> indices,
    TensorView<float4> rast,
    TensorView<float3> output)
{
    // Interpolate the attr into output based on the rast result
    // (barycentric coordinates in xyz, triangle idx in w).
    // One thread handles one output texel.

    uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();

    // Valid indices are 0 .. size - 1, so a thread whose index equals the
    // size must exit as well (the previous '>' let it through).
    if (dispatch_id.x >= output.size(0) || dispatch_id.y >= output.size(1))
        return;

    float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
    int triangle_idx = int(barycentric.w);

    // A negative triangle index marks a texel that no triangle covers.
    if (triangle_idx < 0) {
        output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
        return;
    }

    int3 tri = indices[triangle_idx];
    float3 v1 = attr[tri.x];
    float3 v2 = attr[tri.y];
    float3 v3 = attr[tri.z];

    output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
}
57
+
58
[AutoPyBindCUDA]
[CUDAKernel]
void bake_uv(
    TensorView<float2> uv,
    TensorView<int3> indices,
    TensorView<float4> output)
{
    // For every output texel, find the first triangle (in index order)
    // whose UV footprint contains the texel and store its barycentric
    // coordinates plus the triangle index; texels without a hit get
    // (0, 0, 0, -1). One thread handles one output texel.

    uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();

    // output is indexed as [dispatch_id.x, dispatch_id.y], so x is checked
    // against dimension 0 and y against dimension 1. Valid indices are
    // 0 .. size - 1, hence '>=' (the previous '>' let index == size through).
    if (dispatch_id.x >= output.size(0) || dispatch_id.y >= output.size(1))
        return;

    // We index x,y but the original coords are HW. So swap them
    float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
    // Normalize to [0, 1]
    pixel_coord /= float2(output.size(1), output.size(0));
    pixel_coord = clamp(pixel_coord, 0.0, 1.0);
    // Flip the second (row-derived) coordinate
    pixel_coord.y = 1 - pixel_coord.y;

    for (int i = 0; i < indices.size(0); i++) {
        int3 tri = indices[i];
        float2 v1 = float2(uv[tri.x].x, uv[tri.x].y);
        float2 v2 = float2(uv[tri.y].x, uv[tri.y].y);
        float2 v3 = float2(uv[tri.z].x, uv[tri.z].y);

        float u, v, w;
        bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);

        if (hit){
            output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
            return;
        }
    }

    output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
}
sf3d/sf3d_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import rembg
5
+ import torch
6
+ from PIL import Image
7
+
8
+ import sf3d.models.utils as sf3d_utils
9
+
10
+
11
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
    """Build a camera intrinsic matrix and a size-normalized copy of it.

    Args:
        fov_deg: Field of view in degrees (converted to radians before it is
            passed to ``get_intrinsic_from_fov``).
        cond_height: Image height in pixels.
        cond_width: Image width in pixels.

    Returns:
        ``(intrinsic, intrinsic_normed_cond)``. In the normalized copy the
        entries ``[0, 0]`` and ``[0, 2]`` are divided by the width and the
        entries ``[1, 1]`` and ``[1, 2]`` by the height.
    """
    fov_rad = np.deg2rad(fov_deg)
    intrinsic = sf3d_utils.get_intrinsic_from_fov(
        fov_rad,
        H=cond_height,
        W=cond_width,
    )

    intrinsic_normed_cond = intrinsic.clone()
    # Row 0 is scaled by the width, row 1 by the height: the diagonal entry
    # and the entry in column 2 of each row.
    for row, extent in ((0, cond_width), (1, cond_height)):
        intrinsic_normed_cond[..., row, 2] /= extent
        intrinsic_normed_cond[..., row, row] /= extent

    return intrinsic, intrinsic_normed_cond
24
+
25
+
26
def default_cond_c2w(distance: float):
    """Return the fixed 4x4 float32 conditioning camera-to-world matrix.

    The rotation part is a constant axis permutation; ``distance`` is placed
    in the first row of the translation column.
    """
    pose_rows = [
        [0, 0, 1, distance],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]
    return torch.as_tensor(pose_rows).float()
36
+
37
+
38
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Run ``rembg.remove`` on the image unless it already has transparency.

    An image counts as already having transparency when its mode is
    ``"RGBA"`` and the minimum of its alpha channel is below 255.

    Args:
        image: Input image.
        rembg_session: Passed to ``rembg.remove`` as ``session``.
        force: Run the removal even if the image already has transparency.
        **rembg_kwargs: Extra keyword arguments for ``rembg.remove``.

    Returns:
        The processed image, or the input image unchanged when removal was
        skipped.
    """
    # getextrema()[3] is the (min, max) pair of the fourth band (alpha).
    already_transparent = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if not already_transparent or force:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
51
+
52
+
53
def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    """Crop to the foreground, pad to a square and add a transparent border.

    The foreground is the bounding box of all pixels with alpha > 0. It is
    centered on a transparent square canvas, which is then padded so that
    the original square occupies ``ratio`` of the final side length.

    Args:
        image: Four-channel (RGBA) image.
        ratio: Fraction of the output side length covered by the squared
            foreground.

    Returns:
        A new RGBA image.

    Raises:
        ValueError: If the image has no pixel with alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4

    rows, cols = np.where(image[..., 3] > 0)
    if rows.size == 0:
        # Fail with a clear message instead of NumPy's error about reducing
        # an empty array.
        raise ValueError("Image has no foreground pixels (alpha is 0 everywhere)")
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()

    # crop the foreground; max indices are inclusive, slice ends are
    # exclusive, so add one to keep the last foreground row and column
    fg = image[y1 : y2 + 1, x1 : x2 + 1]

    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = Image.fromarray(new_image, mode="RGBA")
    return new_image