Spaces:
Sleeping
Sleeping
feat: Add gs_utils for gs export
Browse files- app.py +20 -41
- gs_utils.py +106 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -17,7 +17,7 @@ from torchvision import transforms
|
|
17 |
from PIL import Image
|
18 |
import open3d as o3d
|
19 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
20 |
-
|
21 |
|
22 |
# Default values
|
23 |
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
|
@@ -29,15 +29,8 @@ OPENGL = np.array([[1, 0, 0, 0],
|
|
29 |
[0, 0, -1, 0],
|
30 |
[0, 0, 0, 1]])
|
31 |
|
32 |
-
def export_geometry(geometry
|
33 |
-
|
34 |
-
if not isinstance(geometry, o3d.geometry.PointCloud):
|
35 |
-
raise ValueError("Expected an Open3D PointCloud object when as_pointcloud is True")
|
36 |
-
output_path = tempfile.mktemp(suffix='.ply')
|
37 |
-
else:
|
38 |
-
if not isinstance(geometry, o3d.geometry.TriangleMesh):
|
39 |
-
raise ValueError("Expected an Open3D TriangleMesh object when as_pointcloud is False")
|
40 |
-
output_path = tempfile.mktemp(suffix='.obj')
|
41 |
|
42 |
# Apply rotation
|
43 |
rot = np.eye(4)
|
@@ -45,11 +38,7 @@ def export_geometry(geometry, as_pointcloud=False):
|
|
45 |
transform = np.linalg.inv(OPENGL @ rot)
|
46 |
geometry.transform(transform)
|
47 |
|
48 |
-
|
49 |
-
if as_pointcloud:
|
50 |
-
o3d.io.write_point_cloud(output_path, geometry, write_ascii=False, compressed=True)
|
51 |
-
else:
|
52 |
-
o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
|
53 |
|
54 |
return output_path
|
55 |
|
@@ -176,7 +165,7 @@ def generate_mask(image: np.ndarray):
|
|
176 |
return mask_np
|
177 |
@torch.no_grad()
|
178 |
def reconstruct(video_path, conf_thresh, kf_every,
|
179 |
-
|
180 |
# Extract frames from video
|
181 |
demo_path = extract_frames(video_path)
|
182 |
|
@@ -220,31 +209,21 @@ def reconstruct(video_path, conf_thresh, kf_every,
|
|
220 |
pcds.append(pcd)
|
221 |
|
222 |
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
|
223 |
-
|
224 |
-
if as_pointcloud:
|
225 |
-
o3d_geometry = pcd_combined
|
226 |
-
else:
|
227 |
-
o3d_geometry = point2mesh(pcd_combined)
|
228 |
|
229 |
# Create coarse result
|
230 |
-
coarse_output_path = export_geometry(o3d_geometry
|
231 |
|
232 |
yield coarse_output_path, None
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
o3d_geometry = point2mesh(transformed_pcds)
|
243 |
-
|
244 |
-
# Create coarse result
|
245 |
-
refined_output_path = export_geometry(o3d_geometry, as_pointcloud)
|
246 |
-
|
247 |
-
yield coarse_output_path, refined_output_path
|
248 |
|
249 |
# Clean up temporary directory
|
250 |
os.system(f"rm -rf {demo_path}")
|
@@ -320,19 +299,19 @@ with gr.Blocks(
|
|
320 |
kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
|
321 |
with gr.Row():
|
322 |
remove_background = gr.Checkbox(label="Remove Background", value=False)
|
323 |
-
refine = gr.Checkbox(label="Enable Backend", value=False)
|
324 |
-
as_pointcloud = gr.Checkbox(label="As Pointcloud", value=False)
|
325 |
reconstruct_btn = gr.Button("Reconstruct")
|
326 |
|
327 |
with gr.Column(scale=2):
|
328 |
with gr.Tab("Coarse Model"):
|
329 |
-
coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid",
|
|
|
330 |
with gr.Tab("Refined Model"):
|
331 |
-
refined_model = gr.Model3D(label="Refined
|
|
|
332 |
|
333 |
reconstruct_btn.click(
|
334 |
fn=reconstruct,
|
335 |
-
inputs=[video_input, conf_thresh, kf_every,
|
336 |
outputs=[coarse_model, refined_model]
|
337 |
)
|
338 |
|
|
|
17 |
from PIL import Image
|
18 |
import open3d as o3d
|
19 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
20 |
+
from gs_utils import point2gs
|
21 |
|
22 |
# Default values
|
23 |
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
|
|
|
29 |
[0, 0, -1, 0],
|
30 |
[0, 0, 0, 1]])
|
31 |
|
32 |
+
def export_geometry(geometry):
|
33 |
+
output_path = tempfile.mktemp(suffix='.obj')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
# Apply rotation
|
36 |
rot = np.eye(4)
|
|
|
38 |
transform = np.linalg.inv(OPENGL @ rot)
|
39 |
geometry.transform(transform)
|
40 |
|
41 |
+
o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
|
|
|
|
|
|
|
|
|
42 |
|
43 |
return output_path
|
44 |
|
|
|
165 |
return mask_np
|
166 |
@torch.no_grad()
|
167 |
def reconstruct(video_path, conf_thresh, kf_every,
|
168 |
+
remove_background=False):
|
169 |
# Extract frames from video
|
170 |
demo_path = extract_frames(video_path)
|
171 |
|
|
|
209 |
pcds.append(pcd)
|
210 |
|
211 |
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
|
212 |
+
o3d_geometry = point2mesh(pcd_combined)
|
|
|
|
|
|
|
|
|
213 |
|
214 |
# Create coarse result
|
215 |
+
coarse_output_path = export_geometry(o3d_geometry)
|
216 |
|
217 |
yield coarse_output_path, None
|
218 |
|
219 |
+
# Perform global optimization
|
220 |
+
print("Performing global registration...")
|
221 |
+
transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
|
222 |
+
|
223 |
+
# Create refined result (Gaussian splat PLY)
|
224 |
+
refined_output_path = tempfile.mktemp(suffix='.ply')
|
225 |
+
point2gs(refined_output_path, transformed_pcds)
|
226 |
+
yield coarse_output_path, refined_output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
# Clean up temporary directory
|
229 |
os.system(f"rm -rf {demo_path}")
|
|
|
299 |
kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
|
300 |
with gr.Row():
|
301 |
remove_background = gr.Checkbox(label="Remove Background", value=False)
|
|
|
|
|
302 |
reconstruct_btn = gr.Button("Reconstruct")
|
303 |
|
304 |
with gr.Column(scale=2):
|
305 |
with gr.Tab("Coarse Model"):
|
306 |
+
coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid",
|
307 |
+
clear_color=[0.0, 0.0, 0.0, 0.0])
|
308 |
with gr.Tab("Refined Model"):
|
309 |
+
refined_model = gr.Model3D(label="Refined Gaussian Splatting", display_mode="solid",
|
310 |
+
clear_color=[0.0, 0.0, 0.0, 0.0])
|
311 |
|
312 |
reconstruct_btn.click(
|
313 |
fn=reconstruct,
|
314 |
+
inputs=[video_input, conf_thresh, kf_every, remove_background],
|
315 |
outputs=[coarse_model, refined_model]
|
316 |
)
|
317 |
|
gs_utils.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
from plyfile import PlyElement, PlyData
|
4 |
+
import open3d as o3d
|
5 |
+
|
6 |
+
def get_f_dc(colors):
    """Build the ``f_dc`` colour term for each point.

    The colours are passed through ``RGB2SH`` and a new axis of length one is
    inserted at position 2, so an ``(N, C)`` input yields an ``(N, C, 1)``
    result.

    Args:
        colors: Array of per-point colours with at least two dimensions.

    Returns:
        The converted colours with an extra axis inserted at index 2.
    """
    sh_colors = RGB2SH(colors)
    return sh_colors[:, :, np.newaxis]
|
8 |
+
|
9 |
+
def get_f_rest(points, max_sh_degree=3):
    """Return zero-filled ``f_rest`` coefficients for every point.

    Args:
        points: Array whose first dimension is the number of points.
        max_sh_degree: Degree used to size the coefficient axis; the axis has
            ``(max_sh_degree + 1) ** 2 - 1`` entries.

    Returns:
        A zero array of shape ``(N, (max_sh_degree + 1) ** 2 - 1, 3)``.
    """
    num_points = points.shape[0]
    num_coeffs = (max_sh_degree + 1) ** 2 - 1
    return np.zeros((num_points, num_coeffs, 3))
|
12 |
+
|
13 |
+
def get_opacity(points):
    """Return the same stored opacity value for every point.

    Each point is given the value ``inverse_sigmoid(0.5)``.

    Args:
        points: Array whose first dimension is the number of points.

    Returns:
        An ``(N, 1)`` array of identical opacity values.
    """
    num_points = points.shape[0]
    half = np.full((num_points, 1), 0.5)
    return inverse_sigmoid(half)
|
15 |
+
|
16 |
+
def get_scales(points):
    """Return the default per-point scales in log space.

    Every point gets a linear scale of ``0.0015`` on the first two axes and
    ``1e-6`` on the third; the natural log of those values is returned.
    NOTE(review): the tiny third-axis value presumably makes each splat a thin
    disc - confirm with the viewer's conventions.

    Args:
        points: Array whose first dimension is the number of points.

    Returns:
        An ``(N, 3)`` array of log-scales.
    """
    linear_scales = np.full((points.shape[0], 3), 0.0015)
    linear_scales[:, 2] = 1e-6
    return np.log(linear_scales)
|
21 |
+
|
22 |
+
def get_rotation(normals):
    """Return one quaternion per point, derived from the point's normal.

    Rows whose normal has a finite, non-zero length are converted with
    ``normal2rotation``. Every other row (zero, NaN or infinite normal) gets
    the identity quaternion ``[1, 0, 0, 0]`` (scalar first). An all-zero
    quaternion has zero norm and cannot be normalised into a rotation, so it
    is never emitted.

    Args:
        normals: ``(N, 3)`` array of per-point normals.

    Returns:
        An ``(N, 4)`` float array of quaternions.

    Raises:
        ValueError: If ``normals`` is None; without an array the number of
            points is unknown, so no output can be sized.
    """
    if normals is None:
        raise ValueError("normals must be an (N, 3) array, got None")

    normals = np.asarray(normals)

    # Start from the identity rotation for every point.
    quats = np.zeros((normals.shape[0], 4))
    quats[:, 0] = 1.0

    # Only rows with a usable normal are converted; normalising a zero-length
    # normal would divide by zero and yield NaN.
    lengths = np.linalg.norm(normals, axis=1)
    usable = np.isfinite(lengths) & (lengths > 0)
    if usable.any():
        quats[usable] = normal2rotation(normals[usable])
    return quats
|
27 |
+
|
28 |
+
def RGB2SH(rgb):
    """Convert colour values to the coefficient form stored in ``f_dc``.

    Computes ``(rgb - 0.5) / C0`` where ``C0 = 0.28209479177387814``, which is
    numerically ``1 / (2 * sqrt(pi))``.

    Args:
        rgb: Scalar or array of colour values.

    Returns:
        The shifted and rescaled values, same shape as ``rgb``.
    """
    sh_c0 = 0.28209479177387814
    centered = rgb - 0.5
    return centered / sh_c0
|
30 |
+
|
31 |
+
def inverse_sigmoid(x):
    """Return the logit of ``x``, i.e. ``log(x / (1 - x))``.

    Args:
        x: Scalar or array of values.

    Returns:
        The element-wise logit, same shape as ``x``.
    """
    odds = x / (1 - x)
    return np.log(odds)
|
33 |
+
|
34 |
+
def normal2rotation(n):
    """Build a quaternion per normal from a frame whose third axis is the normal.

    For each row the normal is normalised, the fixed axis ``[1, 0, 0]`` is
    projected onto the plane perpendicular to it to get a first axis, and the
    cross product of the two gives a second axis. The three vectors are
    stacked as the columns of a 3x3 matrix and passed to
    ``rotmat2quaternion``.

    NOTE(review): both sign factors evaluate to 0 when the component they
    inspect is exactly 0, which zeroes that axis; a zero-length normal also
    divides by zero. Callers should avoid such inputs.

    Args:
        n: ``(N, 3)`` array of normals.

    Returns:
        The ``(N, 4)`` quaternions returned by ``rotmat2quaternion``.
    """
    unit_n = n / np.linalg.norm(n, axis=1, keepdims=True)

    # Remove from the reference axis its component along the normal.
    ref_axis = np.tile([[1, 0, 0]], (unit_n.shape[0], 1))
    along_normal = np.sum(ref_axis * unit_n, axis=1, keepdims=True)
    axis0 = ref_axis - along_normal * unit_n
    axis0 = axis0 * np.sign(axis0[:, :1])
    axis0 = axis0 / np.linalg.norm(axis0, axis=1, keepdims=True)

    # Second axis is perpendicular to both the normal and the first axis.
    axis1 = np.cross(unit_n, axis0)
    sign_factor = np.sign(axis1[:, 1:2]) * np.sign(unit_n[:, 2:])
    axis1 = axis1 * sign_factor

    frames = np.stack([axis0, axis1, unit_n], axis=-1)
    return rotmat2quaternion(frames)
|
45 |
+
|
46 |
+
def rotmat2quaternion(R, normalize=False):
    """Convert a batch of 3x3 matrices to scalar-first quaternions.

    The scalar part is ``sqrt(1 + trace) / 2`` (with ``1e-6`` added to the
    trace) and the vector part comes from the off-diagonal differences divided
    by four times the scalar part.
    NOTE(review): this single-branch formula divides by the scalar part, so it
    is ill-conditioned when ``1 + trace`` is near zero - confirm inputs avoid
    that case.

    Args:
        R: ``(N, 3, 3)`` array of matrices.
        normalize: When True, divide each quaternion by its Euclidean norm.

    Returns:
        An ``(N, 4)`` array laid out as ``(w, x, y, z)``.
    """
    trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] + 1e-6
    w = np.sqrt(1 + trace) / 2
    four_w = 4 * w

    x = (R[:, 2, 1] - R[:, 1, 2]) / four_w
    y = (R[:, 0, 2] - R[:, 2, 0]) / four_w
    z = (R[:, 1, 0] - R[:, 0, 1]) / four_w

    quat = np.stack([w, x, y, z], axis=-1)
    if normalize:
        quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True)
    return quat
|
58 |
+
|
59 |
+
def point2gs(path, pcd, scale=None, max_sh_degree=1):
    """Write a point cloud to ``path`` as a PLY file of Gaussian-splat attributes.

    Each vertex stores, as float32 and in this order: ``x, y, z``,
    ``nx, ny, nz``, ``f_dc_*``, ``f_rest_*``, ``opacity``, ``scale_*`` and
    ``rot_*``.

    Args:
        path: Destination file path. Its parent directory is created when the
            path names one.
        pcd: Object exposing ``points``, ``normals``, ``colors``,
            ``has_normals()`` and ``has_colors()`` (presumably an Open3D point
            cloud - confirm against the caller). Missing normals are written
            as zeros and missing colours as ones.
        scale: Optional linear scales; their natural log is stored. A scalar
            or 1-D value is broadcast to ``(N, 3)``; a 2-D array is used
            as given. When None, ``get_scales`` supplies the default.
        max_sh_degree: Forwarded to ``get_f_rest`` to size the ``f_rest_*``
            block.
    """
    # os.makedirs("") raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Get point cloud data
    xyz = np.asarray(pcd.points)
    num_points = xyz.shape[0]
    normals = np.asarray(pcd.normals) if pcd.has_normals() else np.zeros_like(xyz)
    colors = np.asarray(pcd.colors) if pcd.has_colors() else np.ones_like(xyz)

    # Generate additional attributes
    f_dc = get_f_dc(colors).reshape(num_points, -1)
    f_rest = get_f_rest(xyz, max_sh_degree).reshape(num_points, -1)
    opacities = get_opacity(xyz)
    if scale is not None:
        scale = np.asarray(np.log(scale))
        if scale.ndim < 2:
            # Accept a single value or one value per axis for all points.
            scale = np.broadcast_to(scale, (num_points, 3))
    else:
        scale = get_scales(xyz)
    rotation = get_rotation(normals)

    # Construct list of attributes
    attribute_names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    attribute_names.extend(f'f_dc_{i}' for i in range(f_dc.shape[-1]))
    attribute_names.extend(f'f_rest_{i}' for i in range(f_rest.shape[-1]))
    attribute_names.append('opacity')
    attribute_names.extend(f'scale_{i}' for i in range(scale.shape[1]))
    attribute_names.extend(f'rot_{i}' for i in range(rotation.shape[1]))

    # Create dtype for numpy structured array
    dtype_full = [(attribute, 'f4') for attribute in attribute_names]

    # Combine all attributes
    attributes = np.concatenate((
        xyz, normals,
        f_dc,
        f_rest,
        opacities, scale, rotation
    ), axis=1)

    # Ensure attributes match the dtype
    assert attributes.shape[1] == len(dtype_full), f"Mismatch in attribute count. Expected {len(dtype_full)}, got {attributes.shape[1]}"

    # Fill the structured array one column at a time; this stays vectorised
    # instead of building a Python tuple for every point.
    elements = np.empty(num_points, dtype=dtype_full)
    for attribute, column in zip(attribute_names, attributes.T):
        elements[attribute] = column

    # Create PlyElement and save
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(path)
|
requirements.txt
CHANGED
@@ -18,4 +18,5 @@ transformers
|
|
18 |
kornia
|
19 |
timm
|
20 |
numpy==1.26.4
|
21 |
-
open3d
|
|
|
|
18 |
kornia
|
19 |
timm
|
20 |
numpy==1.26.4
|
21 |
+
open3d
|
22 |
+
plyfile
|