bluestyle97 commited on
Commit
8e25beb
·
verified ·
1 Parent(s): 3472618

Create mesh_optim.py

Browse files
Files changed (1) hide show
  1. freesplatter/utils/mesh_optim.py +203 -0
freesplatter/utils/mesh_optim.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import numpy as np
3
+ import torch
4
+ import utils3d
5
+ import nvdiffrast.torch as dr
6
+ from tqdm import tqdm
7
+ import trimesh
8
+ import trimesh.visual
9
+ import xatlas
10
+ import cv2
11
+ from PIL import Image
12
+ import fast_simplification
13
+
14
+ from freesplatter.utils.mesh import Mesh
15
+
16
+
17
+ def parametrize_mesh(vertices: np.array, faces: np.array):
18
+ """
19
+ Parametrize a mesh to a texture space, using xatlas.
20
+ Args:
21
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
22
+ faces (np.array): Faces of the mesh. Shape (F, 3).
23
+ """
24
+
25
+ vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
26
+
27
+ vertices = vertices[vmapping]
28
+ faces = indices
29
+
30
+ return vertices, faces, uvs
31
+
32
+
33
+ def bake_texture(
34
+ vertices: np.array,
35
+ faces: np.array,
36
+ uvs: np.array,
37
+ observations: List[np.array],
38
+ masks: List[np.array],
39
+ extrinsics: List[np.array],
40
+ intrinsics: List[np.array],
41
+ texture_size: int = 2048,
42
+ near: float = 0.1,
43
+ far: float = 10.0,
44
+ mode: Literal['fast', 'opt'] = 'opt',
45
+ lambda_tv: float = 1e-2,
46
+ verbose: bool = False,
47
+ ):
48
+ """
49
+ Bake texture to a mesh from multiple observations.
50
+ Args:
51
+ vertices (np.array): Vertices of the mesh. Shape (V, 3).
52
+ faces (np.array): Faces of the mesh. Shape (F, 3).
53
+ uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
54
+ observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
55
+ masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
56
+ extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
57
+ intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
58
+ texture_size (int): Size of the texture.
59
+ near (float): Near plane of the camera.
60
+ far (float): Far plane of the camera.
61
+ mode (Literal['fast', 'opt']): Mode of texture baking.
62
+ lambda_tv (float): Weight of total variation loss in optimization.
63
+ verbose (bool): Whether to print progress.
64
+ """
65
+ vertices = torch.tensor(vertices).float().cuda()
66
+ faces = torch.tensor(faces.astype(np.int32)).cuda()
67
+ uvs = torch.tensor(uvs).float().cuda()
68
+ observations = [torch.tensor(obs).float().cuda() for obs in observations]
69
+ masks = [torch.tensor(m>1e-2).bool().cuda() for m in masks]
70
+ views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).float().cuda()) for extr in extrinsics]
71
+ projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).float().cuda(), near, far) for intr in intrinsics]
72
+
73
+ if mode == 'fast':
74
+ texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
75
+ texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
76
+ rastctx = utils3d.torch.RastContext(backend='cuda')
77
+ for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
78
+ with torch.no_grad():
79
+ rast = utils3d.torch.rasterize_triangle_faces(
80
+ rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
81
+ )
82
+ uv_map = rast['uv'][0].detach().flip(0)
83
+ mask = rast['mask'][0].detach().bool() & masks[0]
84
+
85
+ # nearest neighbor interpolation
86
+ uv_map = (uv_map * texture_size).floor().long()
87
+ obs = observation[mask]
88
+ uv_map = uv_map[mask]
89
+ idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
90
+ texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
91
+ texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
92
+
93
+ mask = texture_weights > 0
94
+ texture[mask] /= texture_weights[mask][:, None]
95
+ texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
96
+
97
+ # inpaint
98
+ mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
99
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
100
+
101
+ elif mode == 'opt':
102
+ rastctx = utils3d.torch.RastContext(backend='cuda')
103
+ observations = [observations.flip(0) for observations in observations]
104
+ masks = [m.flip(0) for m in masks]
105
+ _uv = []
106
+ _uv_dr = []
107
+ for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
108
+ with torch.no_grad():
109
+ rast = utils3d.torch.rasterize_triangle_faces(
110
+ rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
111
+ )
112
+ _uv.append(rast['uv'].detach())
113
+ _uv_dr.append(rast['uv_dr'].detach())
114
+
115
+ texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
116
+ optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
117
+
118
+ def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
119
+ return start_lr * (end_lr / start_lr) ** (step / total_steps)
120
+
121
+ def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
122
+ return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
123
+
124
+ def tv_loss(texture):
125
+ return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
126
+ torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
127
+
128
+ total_steps = 2500
129
+ with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
130
+ for step in range(total_steps):
131
+ optimizer.zero_grad()
132
+ selected = np.random.randint(0, len(views))
133
+ uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
134
+ render = dr.texture(texture, uv, uv_dr)[0]
135
+ loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
136
+ if lambda_tv > 0:
137
+ loss += lambda_tv * tv_loss(texture)
138
+ loss.backward()
139
+ optimizer.step()
140
+ # annealing
141
+ optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5)
142
+ pbar.set_postfix({'loss': loss.item()})
143
+ pbar.update()
144
+ texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
145
+ mask = 1 - utils3d.torch.rasterize_triangle_faces(
146
+ rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
147
+ )['mask'][0].detach().cpu().numpy().astype(np.uint8)
148
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
149
+ else:
150
+ raise ValueError(f'Unknown mode: {mode}')
151
+
152
+ return texture
153
+
154
+
155
+ def optimize_mesh(
156
+ mesh: Mesh,
157
+ images: torch.Tensor,
158
+ masks: torch.Tensor,
159
+ extrinsics: torch.Tensor,
160
+ intrinsics: torch.Tensor,
161
+ simplify: float = 0.95,
162
+ texture_size: int = 1024,
163
+ verbose: bool = False,
164
+ ) -> trimesh.Trimesh:
165
+ """
166
+ Convert a generated asset to a glb file.
167
+ Args:
168
+ mesh (Mesh): Extracted mesh.
169
+ simplify (float): Ratio of faces to remove in simplification.
170
+ texture_size (int): Size of the texture.
171
+ verbose (bool): Whether to print progress.
172
+ """
173
+ vertices = mesh.v.cpu().numpy()
174
+ faces = mesh.f.cpu().numpy()
175
+
176
+ # mesh simplification
177
+ max_faces = 50000
178
+ mesh_reduction = max(1 - max_faces / faces.shape[0], simplify)
179
+ vertices, faces = fast_simplification.simplify(
180
+ vertices, faces, target_reduction=mesh_reduction)
181
+
182
+ # parametrize mesh
183
+ vertices, faces, uvs = parametrize_mesh(vertices, faces)
184
+
185
+ # bake texture
186
+ images = [images[i].cpu().numpy() for i in range(len(images))]
187
+ masks = [masks[i].cpu().numpy() for i in range(len(masks))]
188
+ extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
189
+ intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
190
+ texture = bake_texture(
191
+ vertices.astype(float), faces.astype(float), uvs,
192
+ images, masks, extrinsics, intrinsics,
193
+ texture_size=texture_size,
194
+ mode='opt',
195
+ lambda_tv=0.01,
196
+ verbose=verbose
197
+ )
198
+ texture = Image.fromarray(texture)
199
+
200
+ # rotate mesh
201
+ vertices = vertices.astype(float) @ np.array([[-1, 0, 0], [0, 0, 1], [0, 1, 0]]).astype(float)
202
+ mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture))
203
+ return mesh