FrozenBurning commited on
Commit
cb029d0
·
1 Parent(s): e3760d0

Update fast uv unwrap

Browse files
Files changed (5) hide show
  1. app.py +15 -14
  2. configs/inference_dit.yml +1 -0
  3. inference.py +26 -25
  4. requirements.txt +2 -1
  5. utils/uv_unwrap.py +685 -0
app.py CHANGED
@@ -139,23 +139,23 @@ def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
139
  recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
140
  visualize_video_primvolume(config.output_dir, batch, recon_param, 15, rm, device)
141
  prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
142
- torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))
143
 
144
- return output_rgb_video_path, output_prim_video_path, output_mat_video_path, gr.update(interactive=True)
145
-
146
- def export_mesh(remesh="No", mc_resolution=256, decimate=100000):
147
  # exporting GLB mesh
148
  output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
149
  if remesh == "No":
150
  config.inference.remesh = False
151
  elif remesh == "Yes":
152
  config.inference.remesh = True
 
 
 
 
153
  config.inference.decimate = decimate
154
  config.inference.mc_resolution = mc_resolution
155
  config.inference.batch_size = 8192
156
- denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
157
- primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
158
- model_primx.load_state_dict(primx_ckpt_weight)
159
  model_primx.to(device)
160
  model_primx.eval()
161
  with torch.no_grad():
@@ -179,6 +179,7 @@ _DESCRIPTION = '''
179
  block = gr.Blocks(title=_TITLE).queue()
180
  with block:
181
  current_fg_state = gr.State()
 
182
  with gr.Row():
183
  with gr.Column(scale=1):
184
  gr.Markdown('# ' + _TITLE)
@@ -192,17 +193,17 @@ with block:
192
  # background removal
193
  removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
194
  # inference steps
195
- input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
196
  # random seed
197
  input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=0.5, value=6, info="Typically CFG in a range of 4-7")
198
  # random seed
199
  input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
200
- # gen button
201
- button_gen = gr.Button(value="Generate", interactive=False)
202
  with gr.Row():
203
- input_mc_resolution = gr.Radio(choices=[64, 128, 256], label="MC Resolution", value=128, info="Cube resolution for mesh extraction")
204
  input_remesh = gr.Radio(choices=["No", "Yes"], label="Remesh", value="No", info="Remesh or not?")
205
- export_glb_btn = gr.Button(value="Export GLB", interactive=False)
 
 
206
 
207
  with gr.Column(scale=1):
208
  with gr.Row():
@@ -246,8 +247,8 @@ with block:
246
  )
247
 
248
  input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
249
- button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
250
- export_glb_btn.click(export_mesh, inputs=[input_remesh, input_mc_resolution], outputs=[output_glb, hdr_row])
251
 
252
  gr.Examples(
253
  examples=[
 
139
  recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
140
  visualize_video_primvolume(config.output_dir, batch, recon_param, 15, rm, device)
141
  prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
142
+ return output_rgb_video_path, output_prim_video_path, output_mat_video_path, gr.update(interactive=True), prim_params
143
 
144
+ def export_mesh(prim_params, uv_unwrap="Faster", remesh="No", mc_resolution=256, decimate=100000):
 
 
145
  # exporting GLB mesh
146
  output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
147
  if remesh == "No":
148
  config.inference.remesh = False
149
  elif remesh == "Yes":
150
  config.inference.remesh = True
151
+ if uv_unwrap == "Faster":
152
+ config.inference.fast_unwrap = True
153
+ elif uv_unwrap == "Better":
154
+ config.inference.fast_unwrap = False
155
  config.inference.decimate = decimate
156
  config.inference.mc_resolution = mc_resolution
157
  config.inference.batch_size = 8192
158
+ model_primx.load_state_dict(prim_params)
 
 
159
  model_primx.to(device)
160
  model_primx.eval()
161
  with torch.no_grad():
 
179
  block = gr.Blocks(title=_TITLE).queue()
180
  with block:
181
  current_fg_state = gr.State()
182
+ prim_param_state = gr.State()
183
  with gr.Row():
184
  with gr.Column(scale=1):
185
  gr.Markdown('# ' + _TITLE)
 
193
  # background removal
194
  removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
195
  # inference steps
196
+ input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps. Larger for robustness but slower.", value=25)
197
  # random seed
198
  input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=0.5, value=6, info="Typically CFG in a range of 4-7")
199
  # random seed
200
  input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
 
 
201
  with gr.Row():
202
+ input_mc_resolution = gr.Radio(choices=[64, 128, 256], label="MC Resolution", value=128, info="Cube resolution for mesh extraction. Larger for better quality but slower.")
203
  input_remesh = gr.Radio(choices=["No", "Yes"], label="Remesh", value="No", info="Remesh or not?")
204
+ input_unwrap = gr.Radio(choices=["Faster", "Better"], label="UV", value="Faster", info="UV unwrapping algorithm. Trade-off between quality and speed.")
205
+ # gen button
206
+ button_gen = gr.Button(value="Generate", interactive=False)
207
 
208
  with gr.Column(scale=1):
209
  with gr.Row():
 
247
  )
248
 
249
  input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
250
+ button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn, prim_param_state])
251
+ prim_param_state.change(export_mesh, inputs=[prim_param_state, input_unwrap, input_remesh, input_mc_resolution], outputs=[output_glb, hdr_row])
252
 
253
  gr.Examples(
254
  examples=[
configs/inference_dit.yml CHANGED
@@ -10,6 +10,7 @@ inference:
10
  seed: ${global_seed}
11
  precision: fp16
12
  export_glb: True
 
13
  decimate: 100000
14
  mc_resolution: 256
15
  batch_size: 4096
 
10
  seed: ${global_seed}
11
  precision: fp16
12
  export_glb: True
13
+ fast_unwrap: False
14
  decimate: 100000
15
  mc_resolution: 256
16
  batch_size: 4096
inference.py CHANGED
@@ -25,8 +25,9 @@ from scipy.ndimage import binary_dilation, binary_erosion
25
  from sklearn.neighbors import NearestNeighbors
26
  from utils.meshutils import clean_mesh, decimate_mesh
27
  from utils.mesh import Mesh
 
28
  logger = logging.getLogger("inference.py")
29
-
30
 
31
  def remove_background(image: PIL.Image.Image,
32
  rembg_session = None,
@@ -114,22 +115,31 @@ def extract_texmesh(args, model, output_path, device):
114
  w0 = 1024
115
  ssaa = 1
116
  fp16 = True
117
- glctx = dr.RasterizeCudaContext()
118
  v_np = vertices.astype(np.float32)
119
  f_np = triangles.astype(np.int64)
120
  v = torch.from_numpy(vertices).float().contiguous().to(device)
121
- f = torch.from_numpy(triangles.astype(np.int64)).int().contiguous().to(device)
122
- print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
123
- # unwrap uv in contracted space
124
- atlas = xatlas.Atlas()
125
- atlas.add_mesh(v_np, f_np)
126
- chart_options = xatlas.ChartOptions()
127
- chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
128
- pack_options = xatlas.PackOptions()
129
- # pack_options.blockAlign = True
130
- # pack_options.bruteForce = False
131
- atlas.generate(chart_options=chart_options, pack_options=pack_options)
132
- vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
 
 
 
 
 
 
 
 
 
 
133
 
134
  vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
135
  ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
@@ -143,8 +153,8 @@ def extract_texmesh(args, model, output_path, device):
143
  h, w = h0, w0
144
 
145
  rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
146
- xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
147
- mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
148
  # masked query
149
  xyzs = xyzs.view(-1, 3)
150
  mask = (mask > 0).view(-1)
@@ -182,15 +192,6 @@ def extract_texmesh(args, model, output_path, device):
182
  knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
183
  _, indices = knn.kneighbors(inpaint_coords)
184
  feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
185
- # do ssaa after the NN search, in numpy
186
- feats0 = cv2.cvtColor(feats[..., :3].astype(np.uint8), cv2.COLOR_RGB2BGR) # albedo
187
- feats1 = cv2.cvtColor(feats[..., 3:].astype(np.uint8), cv2.COLOR_RGB2BGR) # visibility features
188
- if ssaa > 1:
189
- feats0 = cv2.resize(feats0, (w0, h0), interpolation=cv2.INTER_LINEAR)
190
- feats1 = cv2.resize(feats1, (w0, h0), interpolation=cv2.INTER_LINEAR)
191
-
192
- cv2.imwrite(os.path.join(ins_dir, f'texture.jpg'), feats0)
193
- cv2.imwrite(os.path.join(ins_dir, f'roughness_metallic.jpg'), feats1)
194
 
195
  target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
196
  target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
 
25
  from sklearn.neighbors import NearestNeighbors
26
  from utils.meshutils import clean_mesh, decimate_mesh
27
  from utils.mesh import Mesh
28
+ from utils.uv_unwrap import box_projection_uv_unwrap, compute_vertex_normal
29
  logger = logging.getLogger("inference.py")
30
+ glctx = dr.RasterizeCudaContext()
31
 
32
  def remove_background(image: PIL.Image.Image,
33
  rembg_session = None,
 
115
  w0 = 1024
116
  ssaa = 1
117
  fp16 = True
 
118
  v_np = vertices.astype(np.float32)
119
  f_np = triangles.astype(np.int64)
120
  v = torch.from_numpy(vertices).float().contiguous().to(device)
121
+ f = torch.from_numpy(triangles.astype(np.int64)).to(torch.int64).contiguous().to(device)
122
+ if args.fast_unwrap:
123
+ print(f'[INFO] running box-based fast unwrapping to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
124
+ v_normal = compute_vertex_normal(v, f)
125
+ uv, indices = box_projection_uv_unwrap(v, v_normal, f, 0.02)
126
+ indv_v = v[f].reshape(-1, 3)
127
+ indv_faces = torch.arange(indv_v.shape[0], device=device, dtype=f.dtype).reshape(-1, 3)
128
+ uv_flat = uv[indices].reshape((-1, 2))
129
+ v = indv_v.contiguous()
130
+ f = indv_faces.contiguous()
131
+ ft_np = f.cpu().numpy()
132
+ vt_np = uv_flat.cpu().numpy()
133
+ else:
134
+ print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
135
+ # unwrap uv in contracted space
136
+ atlas = xatlas.Atlas()
137
+ atlas.add_mesh(v_np, f_np)
138
+ chart_options = xatlas.ChartOptions()
139
+ chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
140
+ pack_options = xatlas.PackOptions()
141
+ atlas.generate(chart_options=chart_options, pack_options=pack_options)
142
+ _, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
143
 
144
  vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
145
  ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
 
153
  h, w = h0, w0
154
 
155
  rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
156
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f.int()) # [1, h, w, 3]
157
+ mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f.int()) # [1, h, w, 1]
158
  # masked query
159
  xyzs = xyzs.view(-1, 3)
160
  mask = (mask > 0).view(-1)
 
192
  knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
193
  _, indices = knn.kneighbors(inpaint_coords)
194
  feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
 
 
 
 
 
 
 
 
 
195
 
196
  target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
197
  target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
requirements.txt CHANGED
@@ -21,4 +21,5 @@ diffusers==0.19.3
21
  ninja
22
  imageio
23
  imageio-ffmpeg
24
- gradio-litmodel3d==0.0.1
 
 
21
  ninja
22
  imageio
23
  imageio-ffmpeg
24
+ gradio-litmodel3d==0.0.1
25
+ jaxtyping==0.2.31
utils/uv_unwrap.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from jaxtyping import Bool, Float, Integer, Int, Num
7
+ from torch import Tensor
8
+
9
+ def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
10
+ # One pad for determinant
11
+ tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
12
+ det_tri = torch.det(tri_sq)
13
+ tri_rev = torch.cat(
14
+ (tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
15
+ )
16
+ tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
17
+ return tri_sq
18
+
19
+ def triangle_intersection_2d(
20
+ t1: Float[Tensor, "*B 3 2"],
21
+ t2: Float[Tensor, "*B 3 2"],
22
+ eps=1e-12,
23
+ ) -> Float[Tensor, "*B"]: # noqa: F821
24
+ """Returns True if triangles collide, False otherwise"""
25
+
26
+ def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
27
+ logdetx = torch.logdet(x.double())
28
+ if eps is None:
29
+ return ~torch.isfinite(logdetx)
30
+ return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
31
+
32
+ t1s = tri_winding(t1)
33
+ t2s = tri_winding(t2)
34
+
35
+ # Assume the triangles do not collide in the begging
36
+ ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
37
+ for i in range(3):
38
+ edge = torch.roll(t1s, i, dims=1)[:, :2, :]
39
+ # Check if all points of triangle 2 lay on the external side of edge E.
40
+ # If this is the case the triangle do not collide
41
+ upd = (
42
+ chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
43
+ & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
44
+ & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
45
+ )
46
+ # Here no collision is still True due to inversion
47
+ ret = ret | upd
48
+
49
+ for i in range(3):
50
+ edge = torch.roll(t2s, i, dims=1)[:, :2, :]
51
+
52
+ upd = (
53
+ chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
54
+ & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
55
+ & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
56
+ )
57
+ # Here no collision is still True due to inversion
58
+ ret = ret | upd
59
+
60
+ return ~ret # Do the inversion
61
+
62
+ def dot(x, y, dim=-1):
63
+ return torch.sum(x * y, dim, keepdim=True)
64
+
65
+ def compute_vertex_normal(v_pos, t_pos_idx):
66
+ i0 = t_pos_idx[:, 0]
67
+ i1 = t_pos_idx[:, 1]
68
+ i2 = t_pos_idx[:, 2]
69
+ v0 = v_pos[i0, :]
70
+ v1 = v_pos[i1, :]
71
+ v2 = v_pos[i2, :]
72
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
73
+ # Splat face normals to vertices
74
+ v_nrm = torch.zeros_like(v_pos)
75
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
76
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
77
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
78
+ # Normalize, replace zero (degenerated) normals with some default value
79
+ v_nrm = torch.where(
80
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
81
+ )
82
+ v_nrm = F.normalize(v_nrm, dim=1)
83
+ if torch.is_anomaly_enabled():
84
+ assert torch.all(torch.isfinite(v_nrm))
85
+ return v_nrm
86
+
87
+ def _box_assign_vertex_to_cube_face(
88
+ vertex_positions: Float[Tensor, "Nv 3"],
89
+ vertex_normals: Float[Tensor, "Nv 3"],
90
+ triangle_idxs: Integer[Tensor, "Nf 3"],
91
+ bbox: Float[Tensor, "2 3"],
92
+ ) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
93
+ # Test to not have a scaled model to fit the space better
94
+ # bbox_min = bbox[:1].mean(-1, keepdim=True)
95
+ # bbox_max = bbox[1:].mean(-1, keepdim=True)
96
+ # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
97
+
98
+ # Create a [0, 1] normalized vertex position
99
+ v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
100
+ # And to [-1, 1]
101
+ v_pos_normalized = 2.0 * v_pos_normalized - 1.0
102
+
103
+ # Get all vertex positions for each triangle
104
+ # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
105
+ v0 = v_pos_normalized[triangle_idxs[:, 0]]
106
+ v1 = v_pos_normalized[triangle_idxs[:, 1]]
107
+ v2 = v_pos_normalized[triangle_idxs[:, 2]]
108
+ tri_stack = torch.stack([v0, v1, v2], dim=1)
109
+
110
+ vn0 = vertex_normals[triangle_idxs[:, 0]]
111
+ vn1 = vertex_normals[triangle_idxs[:, 1]]
112
+ vn2 = vertex_normals[triangle_idxs[:, 2]]
113
+ tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
114
+
115
+ # Just average the normals per face
116
+ face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
117
+
118
+ # Now decide based on the face normal in which box map we project
119
+ # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
120
+ abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
121
+
122
+ axis = torch.tensor(
123
+ [
124
+ [1, 0, 0], # 0
125
+ [-1, 0, 0], # 1
126
+ [0, 1, 0], # 2
127
+ [0, -1, 0], # 3
128
+ [0, 0, 1], # 4
129
+ [0, 0, -1], # 5
130
+ ],
131
+ device=face_normal.device,
132
+ dtype=face_normal.dtype,
133
+ )
134
+ face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
135
+ index = face_normal_axis.argmax(-1)
136
+
137
+ max_axis, uc, vc = (
138
+ torch.ones_like(abs_x),
139
+ torch.zeros_like(tri_stack[..., :1]),
140
+ torch.zeros_like(tri_stack[..., :1]),
141
+ )
142
+ mask_pos_x = index == 0
143
+ max_axis[mask_pos_x] = abs_x[mask_pos_x]
144
+ uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
145
+ vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
146
+
147
+ mask_neg_x = index == 1
148
+ max_axis[mask_neg_x] = abs_x[mask_neg_x]
149
+ uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
150
+ vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
151
+
152
+ mask_pos_y = index == 2
153
+ max_axis[mask_pos_y] = abs_y[mask_pos_y]
154
+ uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
155
+ vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
156
+
157
+ mask_neg_y = index == 3
158
+ max_axis[mask_neg_y] = abs_y[mask_neg_y]
159
+ uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
160
+ vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
161
+
162
+ mask_pos_z = index == 4
163
+ max_axis[mask_pos_z] = abs_z[mask_pos_z]
164
+ uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
165
+ vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
166
+
167
+ mask_neg_z = index == 5
168
+ max_axis[mask_neg_z] = abs_z[mask_neg_z]
169
+ uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
170
+ vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
171
+
172
+ # UC from [-1, 1] to [0, 1]
173
+ max_dim_div = max_axis.max(dim=0, keepdims=True).values
174
+ uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
175
+ vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
176
+
177
+ uv = torch.stack([uc, vc], dim=-1)
178
+
179
+ return uv, index
180
+
181
+
182
+ def _assign_faces_uv_to_atlas_index(
183
+ vertex_positions: Float[Tensor, "Nv 3"],
184
+ triangle_idxs: Integer[Tensor, "Nf 3"],
185
+ face_uv: Float[Tensor, "Nf 3 2"],
186
+ face_index: Integer[Tensor, "Nf 3"],
187
+ ) -> Integer[Tensor, "Nf"]: # noqa: F821
188
+ triangle_pos = vertex_positions[triangle_idxs]
189
+ # We need to do perform 3 overlap checks.
190
+ # The first set is placed in the upper two thirds of the UV atlas.
191
+ # Conceptually, this is the direct visible surfaces from the each cube side
192
+ # The second set is placed in the lower thirds and the left half of the UV atlas.
193
+ # This is the first set of occluded surfaces. They will also be saved in the projected fashion
194
+ # The third pass finds all non assigned faces. They will be placed in the bottom right half of
195
+ # the UV atlas in scattered fashion.
196
+ assign_idx = face_index.clone()
197
+ for overlap_step in range(3):
198
+ overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
199
+ for i in range(overlap_step * 6, (overlap_step + 1) * 6):
200
+ mask = assign_idx == i
201
+ if not mask.any():
202
+ continue
203
+ # Get all elements belonging to the projection face
204
+ uv_triangle = face_uv[mask]
205
+ cur_triangle_pos = triangle_pos[mask]
206
+ # Find the center of the uv coordinates
207
+ center_uv = uv_triangle.mean(dim=1, keepdim=True)
208
+ # And also the radius of the triangle
209
+ uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
210
+
211
+ potentially_overlapping_mask = (
212
+ # Find all close triangles
213
+ (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
214
+ # Do not select the same element by offseting with an large valued identity matrix
215
+ + torch.eye(
216
+ uv_triangle.shape[0],
217
+ device=uv_triangle.device,
218
+ dtype=uv_triangle.dtype,
219
+ ).unsqueeze(-1)
220
+ * 1000
221
+ )
222
+ # Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
223
+ potentially_overlapping_mask = (
224
+ potentially_overlapping_mask
225
+ <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
226
+ ).squeeze(-1)
227
+ overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
228
+
229
+ # Only unique triangles (A|B and B|A should be the same)
230
+ f = torch.min(overlap_coords, dim=-1).values
231
+ s = torch.max(overlap_coords, dim=-1).values
232
+ overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
233
+ first, second = overlap_coords.unbind(-1)
234
+
235
+ # Get the triangles
236
+ tri_1 = uv_triangle[first]
237
+ tri_2 = uv_triangle[second]
238
+
239
+ # Perform the actual set with the reduced number of potentially overlapping triangles
240
+ its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
241
+
242
+ # So we now need to detect which triangles are the occluded ones.
243
+ # We always assume the first to be the visible one (the others should move)
244
+ # In the previous step we use a lexigraphical sort to get the unique pairs
245
+ # In this we use a sort based on the orthographic projection
246
+ ax = 0 if i < 2 else 1 if i < 4 else 2
247
+ use_max = i % 2 == 1
248
+
249
+ tri1_c = cur_triangle_pos[first].mean(dim=1)
250
+ tri2_c = cur_triangle_pos[second].mean(dim=1)
251
+
252
+ mark_first = (
253
+ (tri1_c[..., ax] > tri2_c[..., ax])
254
+ if use_max
255
+ else (tri1_c[..., ax] < tri2_c[..., ax])
256
+ )
257
+ first[mark_first] = second[mark_first]
258
+
259
+ # Lastly the same index can be tested multiple times.
260
+ # If one marks it as overlapping we keep it marked as such.
261
+ # We do this by testing if it has been marked at least once.
262
+ unique_idx, rev_idx = torch.unique(first, return_inverse=True)
263
+
264
+ add = torch.zeros_like(unique_idx, dtype=torch.float32)
265
+ add.index_add_(0, rev_idx, its.float())
266
+ its_mask = add > 0
267
+
268
+ # And fill it in the overlapping indicator
269
+ idx = torch.where(mask)[0][unique_idx]
270
+ overlapping_indicator[idx] = its_mask
271
+
272
+ # Move the index to the overlap regions (shift by 6)
273
+ assign_idx[overlapping_indicator] += 6
274
+
275
+ # We do not care about the correct face placement after the first 2 slices
276
+ max_idx = 6 * 2
277
+ return assign_idx.clamp(0, max_idx)
278
+
279
+
280
+ def _find_slice_offset_and_scale(
281
+ index: Integer[Tensor, "Nf"], # noqa: F821
282
+ ) -> Tuple[
283
+ Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
284
+ ]: # noqa: F821
285
+ # 6 due to the 6 cube faces
286
+ off = 1 / 3
287
+ dupl_off = 1 / 6
288
+
289
+ # Here, we need to decide how to pack the textures in the case of overlap
290
+ def x_offset_calc(x, i):
291
+ offset_calc = i // 6
292
+ # Initial coordinates - just 3x2 grid
293
+ if offset_calc == 0:
294
+ return off * x
295
+ else:
296
+ # Smaller 3x2 grid plus eventual shift to right for
297
+ # second overlap
298
+ return dupl_off * x + min(offset_calc - 1, 1) * 0.5
299
+
300
+ def y_offset_calc(x, i):
301
+ offset_calc = i // 6
302
+ # Initial coordinates - just a 3x2 grid
303
+ if offset_calc == 0:
304
+ return off * x
305
+ else:
306
+ # Smaller coordinates in the lowest row
307
+ return dupl_off * x + off * 2
308
+
309
+ offset_x = torch.zeros_like(index, dtype=torch.float32)
310
+ offset_y = torch.zeros_like(index, dtype=torch.float32)
311
+ offset_x_vals = [0, 1, 2, 0, 1, 2]
312
+ offset_y_vals = [0, 0, 0, 1, 1, 1]
313
+ for i in range(index.max().item() + 1):
314
+ mask = index == i
315
+ if not mask.any():
316
+ continue
317
+ offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
318
+ offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
319
+
320
+ div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
321
+ # All overlap elements are saved in half scale
322
+ div_x[index >= 6] = 6
323
+ div_y = div_x.clone() # Same for y
324
+ # Except for the random overlaps
325
+ div_x[index >= 12] = 2
326
+ # But the random overlaps are saved in a large block in the lower thirds
327
+ div_y[index >= 12] = 3
328
+
329
+ return offset_x, offset_y, div_x, div_y
330
+
331
+
332
+ def rotation_flip_matrix_2d(
333
+ rad: float, flip_x: bool = False, flip_y: bool = False
334
+ ) -> Float[Tensor, "2 2"]:
335
+ cos = math.cos(rad)
336
+ sin = math.sin(rad)
337
+ rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
338
+ flip_mat = torch.tensor(
339
+ [
340
+ [-1 if flip_x else 1, 0],
341
+ [0, -1 if flip_y else 1],
342
+ ],
343
+ dtype=torch.float32,
344
+ )
345
+
346
+ return flip_mat @ rot_mat
347
+
348
+
349
+ def calculate_tangents(
350
+ vertex_positions: Float[Tensor, "Nv 3"],
351
+ vertex_normals: Float[Tensor, "Nv 3"],
352
+ triangle_idxs: Integer[Tensor, "Nf 3"],
353
+ face_uv: Float[Tensor, "Nf 3 2"],
354
+ ) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
355
+ vn_idx = [None] * 3
356
+ pos = [None] * 3
357
+ tex = face_uv.unbind(1)
358
+ for i in range(0, 3):
359
+ pos[i] = vertex_positions[triangle_idxs[:, i]]
360
+ # t_nrm_idx is always the same as t_pos_idx
361
+ vn_idx[i] = triangle_idxs[:, i]
362
+
363
+ tangents = torch.zeros_like(vertex_normals)
364
+ tansum = torch.zeros_like(vertex_normals)
365
+
366
+ # Compute tangent space for each triangle
367
+ duv1 = tex[1] - tex[0]
368
+ duv2 = tex[2] - tex[0]
369
+ dpos1 = pos[1] - pos[0]
370
+ dpos2 = pos[2] - pos[0]
371
+
372
+ tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
373
+
374
+ denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
375
+
376
+ # Avoid division by zero for degenerated texture coordinates
377
+ denom_safe = denom.clip(1e-6)
378
+ tang = tng_nom / denom_safe
379
+
380
+ # Update all 3 vertices
381
+ for i in range(0, 3):
382
+ idx = vn_idx[i][:, None].repeat(1, 3)
383
+ tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
384
+ tansum.scatter_add_(
385
+ 0, idx, torch.ones_like(tang)
386
+ ) # tansum[n_i] = tansum[n_i] + 1
387
+ # Also normalize it. Here we do not normalize the individual triangles first so larger area
388
+ # triangles influence the tangent space more
389
+ tangents = tangents / tansum
390
+
391
+ # Normalize and make sure tangent is perpendicular to normal
392
+ tangents = F.normalize(tangents, dim=1)
393
+ tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
394
+
395
+ return tangents
396
+
397
+
398
+ def _rotate_uv_slices_consistent_space(
399
+ vertex_positions: Float[Tensor, "Nv 3"],
400
+ vertex_normals: Float[Tensor, "Nv 3"],
401
+ triangle_idxs: Integer[Tensor, "Nf 3"],
402
+ uv: Float[Tensor, "Nf 3 2"],
403
+ index: Integer[Tensor, "Nf"], # noqa: F821
404
+ ):
405
+ tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
406
+ pos_stack = torch.stack(
407
+ [
408
+ -vertex_positions[..., 1],
409
+ vertex_positions[..., 0],
410
+ torch.zeros_like(vertex_positions[..., 0]),
411
+ ],
412
+ dim=-1,
413
+ )
414
+ expected_tangents = F.normalize(
415
+ torch.linalg.cross(
416
+ vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
417
+ ),
418
+ -1,
419
+ )
420
+
421
+ actual_tangents = tangents[triangle_idxs]
422
+ expected_tangents = expected_tangents[triangle_idxs]
423
+
424
+ def rotation_matrix_2d(theta):
425
+ c, s = torch.cos(theta), torch.sin(theta)
426
+ return torch.tensor([[c, -s], [s, c]])
427
+
428
+ # Now find the rotation
429
+ index_mod = index % 6 # Shouldn't happen. Just for safety
430
+ for i in range(6):
431
+ mask = index_mod == i
432
+ if not mask.any():
433
+ continue
434
+
435
+ actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
436
+ expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
437
+
438
+ dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
439
+ cross_product = (
440
+ actual_mean_tangent[0] * expected_mean_tangent[1]
441
+ - actual_mean_tangent[1] * expected_mean_tangent[0]
442
+ )
443
+ angle = torch.atan2(cross_product, dot_product)
444
+
445
+ rot_matrix = rotation_matrix_2d(angle).to(mask.device)
446
+ # Center the uv coordinate to be in the range of -1 to 1 and 0 centered
447
+ uv_cur = uv[mask] * 2 - 1 # Center it first
448
+ # Rotate it
449
+ uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
450
+
451
+ # Rescale uv[mask] to be within the 0-1 range
452
+ uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
453
+
454
+ return uv
455
+
456
+
457
+ def _handle_slice_uvs(
458
+ uv: Float[Tensor, "Nf 3 2"],
459
+ index: Integer[Tensor, "Nf"], # noqa: F821
460
+ island_padding: float,
461
+ max_index: int = 6 * 2,
462
+ ) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
463
+ uc, vc = uv.unbind(-1)
464
+
465
+ # Get the second slice (The first overlap)
466
+ index_filter = [index == i for i in range(6, max_index)]
467
+
468
+ # Normalize them to always fully fill the atlas patch
469
+ for i, fi in enumerate(index_filter):
470
+ if fi.sum() > 0:
471
+ # Scale the slice but only up to a factor of 2
472
+ # This keeps the texture resolution with the first slice in line (Half space in UV)
473
+ uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
474
+ vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
475
+
476
+ uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
477
+ vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
478
+
479
+ return torch.stack([uc_padded, vc_padded], dim=-1)
480
+
481
+
482
+ def _handle_remaining_uvs(
483
+ uv: Float[Tensor, "Nf 3 2"],
484
+ index: Integer[Tensor, "Nf"], # noqa: F821
485
+ island_padding: float,
486
+ ) -> Float[Tensor, "Nf 3 2"]:
487
+ uc, vc = uv.unbind(-1)
488
+ # Get all remaining elements
489
+ remaining_filter = index >= 6 * 2
490
+ squares_left = remaining_filter.sum()
491
+
492
+ if squares_left == 0:
493
+ return uv
494
+
495
+ uc = uc[remaining_filter]
496
+ vc = vc[remaining_filter]
497
+
498
+ # Or remaining triangles are distributed in a rectangle
499
+ # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
500
+ ratio = 0.5 * (1 / 3) # 1.5
501
+ # sqrt(744/(0.5*(1/3)))
502
+
503
+ mult = math.sqrt(squares_left / ratio)
504
+ num_square_width = int(math.ceil(0.5 * mult))
505
+ num_square_height = int(math.ceil(squares_left / num_square_width))
506
+
507
+ width = 1 / num_square_width
508
+ height = 1 / num_square_height
509
+
510
+ # The idea is again to keep the texture resolution consistent with the first slice
511
+ # This only occupys half the region in the texture chart but the scaling on the squares
512
+ # assumes full coverage.
513
+ clip_val = min(width, height) * 1.5
514
+ # Now normalize the UVs with taking into account the maximum scaling
515
+ uc = (uc - uc.min(dim=1, keepdim=True).values) / (
516
+ uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
517
+ ).clip(clip_val)
518
+ vc = (vc - vc.min(dim=1, keepdim=True).values) / (
519
+ vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
520
+ ).clip(clip_val)
521
+ # Add a small padding
522
+ uc = (
523
+ uc * (1 - island_padding * num_square_width * 0.5)
524
+ + island_padding * num_square_width * 0.25
525
+ ).clip(0, 1)
526
+ vc = (
527
+ vc * (1 - island_padding * num_square_height * 0.5)
528
+ + island_padding * num_square_height * 0.25
529
+ ).clip(0, 1)
530
+
531
+ uc = uc * width
532
+ vc = vc * height
533
+
534
+ # And calculate offsets for each element
535
+ idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
536
+ x_idx = idx % num_square_width
537
+ y_idx = idx // num_square_width
538
+ # And move each triangle to its own spot
539
+ uc = uc + x_idx[:, None] * width
540
+ vc = vc + y_idx[:, None] * height
541
+
542
+ uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
543
+ vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
544
+
545
+ uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
546
+
547
+ return uv
548
+
549
+
550
+ def _distribute_individual_uvs_in_atlas(
551
+ face_uv: Float[Tensor, "Nf 3 2"],
552
+ assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
553
+ offset_x: Float[Tensor, "Nf"], # noqa: F821
554
+ offset_y: Float[Tensor, "Nf"], # noqa: F821
555
+ div_x: Float[Tensor, "Nf"], # noqa: F821
556
+ div_y: Float[Tensor, "Nf"], # noqa: F821
557
+ island_padding: float,
558
+ ):
559
+ # Place the slice first
560
+ placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
561
+ # Then handle the remaining overlap elements
562
+ placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
563
+
564
+ uc, vc = placed_uv.unbind(-1)
565
+ uc = uc / div_x[:, None] + offset_x[:, None]
566
+ vc = vc / div_y[:, None] + offset_y[:, None]
567
+
568
+ uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
569
+
570
+ return uv
571
+
572
+
573
+ def _get_unique_face_uv(
574
+ uv: Float[Tensor, "Nf 3 2"],
575
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
576
+ unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
577
+ # And add the face to uv index mapping
578
+ vtex_idx = unique_idx.view(-1, 3)
579
+
580
+ return unique_uv, vtex_idx
581
+
582
+
583
+ def _align_mesh_with_main_axis(
584
+ vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
585
+ ) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
586
+ # Use pca to find the 2 main axis (third is derived by cross product)
587
+ # Set the random seed so it's repeatable
588
+ torch.manual_seed(0)
589
+ _, _, v = torch.pca_lowrank(vertex_positions, q=2)
590
+ main_axis, seconday_axis = v[:, 0], v[:, 1]
591
+
592
+ main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
593
+ # Orthogonalize the second axis
594
+ seconday_axis: Float[Tensor, "3"] = F.normalize(
595
+ seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
596
+ )
597
+ # Create perpendicular third axis
598
+ third_axis: Float[Tensor, "3"] = F.normalize(
599
+ torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
600
+ )
601
+
602
+ # Check to which canonical axis each aligns
603
+ main_axis_max_idx = main_axis.abs().argmax().item()
604
+ seconday_axis_max_idx = seconday_axis.abs().argmax().item()
605
+ third_axis_max_idx = third_axis.abs().argmax().item()
606
+
607
+ # Now sort the axes based on the argmax so they align with thecanonoical axes
608
+ # If two axes have the same argmax move one of them
609
+ all_possible_axis = {0, 1, 2}
610
+ cur_index = 1
611
+ while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
612
+ # Find missing axis
613
+ missing_axis = all_possible_axis - set(
614
+ [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
615
+ )
616
+ missing_axis = missing_axis.pop()
617
+ # Just assign it to third axis as it had the smallest contribution to the
618
+ # overall shape
619
+ if cur_index == 1:
620
+ third_axis_max_idx = missing_axis
621
+ elif cur_index == 2:
622
+ seconday_axis_max_idx = missing_axis
623
+ else:
624
+ raise ValueError("Could not find 3 unique axis")
625
+ cur_index += 1
626
+
627
+ if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
628
+ raise ValueError("Could not find 3 unique axis")
629
+
630
+ axes = [None] * 3
631
+ axes[main_axis_max_idx] = main_axis
632
+ axes[seconday_axis_max_idx] = seconday_axis
633
+ axes[third_axis_max_idx] = third_axis
634
+ # Create rotation matrix from the individual axes
635
+ rot_mat = torch.stack(axes, dim=1).T
636
+
637
+ # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
638
+ vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
639
+ vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
640
+
641
+ return vertex_positions, vertex_normals
642
+
643
+
644
+ def box_projection_uv_unwrap(
645
+ vertex_positions: Float[Tensor, "Nv 3"],
646
+ vertex_normals: Float[Tensor, "Nv 3"],
647
+ triangle_idxs: Integer[Tensor, "Nf 3"],
648
+ island_padding: float,
649
+ ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
650
+ # Align the mesh with main axis directions first
651
+ # vertex_positions, vertex_normals = _align_mesh_with_main_axis(
652
+ # vertex_positions, vertex_normals
653
+ # )
654
+
655
+ bbox: Float[Tensor, "2 3"] = torch.stack(
656
+ [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
657
+ )
658
+ # First decide in which cube face the triangle is placed
659
+ face_uv, face_index = _box_assign_vertex_to_cube_face(
660
+ vertex_positions, vertex_normals, triangle_idxs, bbox
661
+ )
662
+
663
+ # Rotate the UV islands in a way that they align with the radial z tangent space
664
+ face_uv = _rotate_uv_slices_consistent_space(
665
+ vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
666
+ )
667
+
668
+ # Then find where where the face is placed in the atlas.
669
+ # This has to detect potential overlaps
670
+ assigned_atlas_index = _assign_faces_uv_to_atlas_index(
671
+ vertex_positions, triangle_idxs, face_uv, face_index
672
+ )
673
+
674
+ # Then figure out the final place in the atlas based on the assignment
675
+ offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
676
+ assigned_atlas_index
677
+ )
678
+
679
+ # Next distribute the faces in the uv atlas
680
+ placed_uv = _distribute_individual_uvs_in_atlas(
681
+ face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
682
+ )
683
+
684
+ # And get the unique per-triangle UV coordinates
685
+ return _get_unique_face_uv(placed_uv)