colin1842 committed on
Commit 2267ec1 · verified · 1 Parent(s): 343e32f

Upload 9 files

Files changed (10)
  1. .gitattributes +1 -0
  2. EDA.ipynb +3 -0
  3. README.md +31 -3
  4. app.py +239 -0
  5. data_hoho_pc2wf.py +204 -0
  6. example_on_training.ipynb +0 -0
  7. feature_solution.py +687 -0
  8. handcrafted_solution.py +245 -0
  9. requirements.txt +21 -0
  10. script.py +223 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ EDA.ipynb filter=lfs diff=lfs merge=lfs -text
EDA.ipynb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:578b3d2e7384b24fe5d283054aa12e2c3ec5c32f9ded7a707af8976c22c188f4
3
+ size 14510073
README.md CHANGED
@@ -1,3 +1,31 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # Handcrafted solution example for the S23DR competition
2
+
3
+ This repo provides an example of a simple algorithm to reconstruct a wireframe and submit it to the S23DR competition.
4
+
5
+
6
+ The repo consists of the following parts:
7
+
8
+ - `script.py` - the main file, which is run by the competition space. It should produce `submission.parquet` as the result of the run.
9
+ - `hoho.py` - the file for parsing the dataset at inference time. Do NOT change it.
10
+ - `handcrafted_solution.py` - contains the actual implementation of the algorithm
11
+ - other `*.py` files - helper I/O and visualization utilities
12
+ - `packages/` - the directory for the Python wheels of the custom packages you want to install and use.
13
+
14
+ ## Solution description
15
+
16
+ The solution is simple.
17
+
18
+ 1. Using the provided (but noisy) semantic segmentation called `gestalt`, it takes the centroids of the vertex classes `apex` and `eave_end_point` and projects them to 3D using the provided (also noisy) monocular depth.
19
+ 2. The vertices are connected using the same segmentation, by checking for the edge classes `['eave', 'ridge', 'rake', 'valley']` to be present.
20
+ 3. All the "per-image" vertex predictions are merged in 3D space if their distance is less than a threshold.
21
+ 4. All vertices with zero connections are removed. (A minimal sketch of steps 1 and 3 is shown after this list.)
22
+
23
+
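To make steps 1 and 3 above concrete, here is a minimal illustrative sketch (not part of the committed code; `backproject` and `merge_close` are hypothetical helper names) of lifting a 2D centroid to 3D with monocular depth and greedily merging nearby 3D points. It follows the same world-to-camera convention for `R`, `t` as `handcrafted_solution.py`:

```python
import numpy as np

def backproject(uv, depth, K, R, t):
    """Lift a pixel (u, v) with metric depth into world coordinates (R, t map world -> camera)."""
    x = (uv[0] - K[0, 2]) / K[0, 0]
    y = (uv[1] - K[1, 2]) / K[1, 1]
    ray = np.array([x, y, 1.0])
    p_cam = depth * ray / np.linalg.norm(ray)  # point in the camera frame
    return R.T @ (p_cam - t)                   # camera -> world

def merge_close(points, th=0.1):
    """Greedy merge for step 3: keep a point only if it is farther than `th` from every point kept so far."""
    points = np.asarray(points, dtype=float)
    kept = []
    for p in points:
        if all(np.linalg.norm(p - q) > th for q in kept):
            kept.append(p)
    return np.array(kept)
```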
24
+ ## Example on the training set
25
+
26
+ See [notebooks/example_on_training.ipynb](notebooks/example_on_training.ipynb).
27
+
28
+ ---
29
+ license: apache-2.0
30
+ ---
31
+
app.py ADDED
@@ -0,0 +1,239 @@
1
+ # import subprocess
2
+ # from pathlib import Path
3
+ # def install_package_from_local_file(package_name, folder='packages'):
4
+ # """
5
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
6
+
7
+ # Parameters:
8
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
9
+ # """
10
+ # try:
11
+ # pth = str(Path(folder) / package_name)
12
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
13
+ # "--no-index", # Do not use package index
14
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
15
+ # package_name]) # Specify the package to install
16
+ # print(f"Package installed successfully from {pth}")
17
+ # except subprocess.CalledProcessError as e:
18
+ # print(f"Failed to install package from {pth}. Error: {e}")
19
+
20
+ # install_package_from_local_file('hoho')
21
+
22
+ import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
23
+ # import subprocess
24
+ # import importlib
25
+ # from pathlib import Path
26
+ # import subprocess
27
+
28
+
29
+ # ### The function below is useful for installing additional python wheels.
30
+ # def install_package_from_local_file(package_name, folder='packages'):
31
+ # """
32
+ # Installs a package from a local .whl file or a directory containing .whl files using pip.
33
+
34
+ # Parameters:
35
+ # path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
36
+ # """
37
+ # try:
38
+ # pth = str(Path(folder) / package_name)
39
+ # subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
40
+ # "--no-index", # Do not use package index
41
+ # "--find-links", pth, # Look for packages in the specified directory or at the file
42
+ # package_name]) # Specify the package to install
43
+ # print(f"Package installed successfully from {pth}")
44
+ # except subprocess.CalledProcessError as e:
45
+ # print(f"Failed to install package from {pth}. Error: {e}")
46
+
47
+
48
+ # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
49
+ # install_package_from_local_file('webdataset')
50
+ # install_package_from_local_file('tqdm')
51
+
52
+ import streamlit as st
53
+ import webdataset as wds
54
+ from tqdm import tqdm
55
+ from typing import Dict
56
+ import pandas as pd
57
+ from transformers import AutoTokenizer
58
+ import os
59
+ import time
60
+ import io
61
+ from PIL import Image as PImage
62
+ import numpy as np
63
+
64
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
65
+ from hoho import proc, Sample
66
+
67
+ def convert_entry_to_human_readable(entry):
68
+ out = {}
69
+ already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
70
+ for k, v in entry.items():
71
+ if k in already_good:
72
+ out[k] = v
73
+ continue
74
+ if k == 'points3d':
75
+ out[k] = read_points3D_binary(fid=io.BytesIO(v))
76
+ if k == 'cameras':
77
+ out[k] = read_cameras_binary(fid=io.BytesIO(v))
78
+ if k == 'images':
79
+ out[k] = read_images_binary(fid=io.BytesIO(v))
80
+ if k in ['ade20k', 'gestalt']:
81
+ out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
82
+ if k == 'depthcm':
83
+ out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
84
+ return out
85
+
86
+ import subprocess
87
+ import sys
88
+ import os
89
+
90
+ import numpy as np
91
+ os.environ['MKL_THREADING_LAYER'] = 'GNU'
92
+ os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
93
+
94
+ def install_package_from_local_file(package_name, folder='packages'):
95
+ """
96
+ Installs a package from a local .whl file or a directory containing .whl files using pip.
97
+
98
+ Parameters:
99
+ package_name (str): The name of the package to install.
100
+ folder (str): The folder where the .whl files are located.
101
+ """
102
+ try:
103
+ pth = str(Path(folder) / package_name)
104
+ subprocess.check_call([sys.executable, "-m", "pip", "install",
105
+ "--no-index", # Do not use package index
106
+ "--find-links", pth, # Look for packages in the specified directory or at the file
107
+ package_name]) # Specify the package to install
108
+ print(f"Package installed successfully from {pth}")
109
+ except subprocess.CalledProcessError as e:
110
+ print(f"Failed to install package from {pth}. Error: {e}")
111
+
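The helper above only installs from wheels that are already present locally. As a hedged sketch (not part of app.py; the flags follow the commented `pip download` line near the top of the file), the wheels could be pre-fetched into `packages/` like this:

```python
import subprocess, sys
from pathlib import Path

def download_wheels(package_name, folder='packages'):
    """Pre-download wheels for `package_name` into packages/<package_name> so that
    install_package_from_local_file() can later install them without network access."""
    target = Path(folder) / package_name
    target.mkdir(parents=True, exist_ok=True)
    subprocess.check_call([sys.executable, "-m", "pip", "download", package_name,
                           "-d", str(target), "--only-binary=:all:"])
```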
112
+ def setup_environment():
113
+ # Uninstall torch if it is already installed
114
+ # packages_to_uninstall = ['torch', 'torchvision', 'torchaudio']
115
+ # for package in packages_to_uninstall:
116
+ # uninstall_package(package)
117
+ # Download required packages
118
+ # pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
119
+ # pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121
120
+ # pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
121
+ # packages_to_download = ['torch==1.13.1', 'torchvision==0.14.1', 'torchaudio==0.13.1']
122
+ # packages_to_download = ['torch==2.1.0', 'torchvision==0.16.0', 'torchaudio==2.1.0']
123
+ # download_packages(packages_to_download, folder='packages/torch')
124
+
125
+ # Install ninja
126
+ # install_package_from_local_file('ninja', folder='packages')
127
+
128
+ # packages_to_download = ['torch==2.1.0', 'torchvision==0.16.0', 'torchaudio==2.1.0']
129
+ # download_folder = 'packages/torch'
130
+
131
+ # Download the packages
132
+ # download_packages(packages_to_download, download_folder)
133
+
134
+ # Install packages from local files
135
+ # install_package_from_local_file('torch', folder='packages')
136
+ # install_package_from_local_file('packages/torch/torchvision-0.16.0-cp38-cp38-manylinux1_x86_64.whl', folder='packages/torch')
137
+ # install_package_from_local_file('packages/torch/torchaudio-2.1.0-cp38-cp38-manylinux1_x86_64.whl', folder='packages/torch')
138
+ # install_package_from_local_file('scikit-learn', folder='packages')
139
+ # install_package_from_local_file('open3d', folder='packages')
140
+ install_package_from_local_file('easydict', folder='packages')
141
+ install_package_from_local_file('setuptools', folder='packages')
142
+ # download_packages(['scikit-learn'], folder='packages/scikit-learn')
143
+ # download_packages(['open3d'], folder='packages/open3d')
144
+ # download_packages(['easydict'], folder='packages/easydict')
145
+
146
+ pc_util_path = os.path.join(os.getcwd(), 'pc_util')
147
+ st.write(f"The path to pc_util is {pc_util_path}")
148
+ if os.path.isdir(pc_util_path):
149
+ os.chdir(pc_util_path)
150
+ st.write(f"Installing pc_util from {pc_util_path}")
151
+ subprocess.check_call([sys.executable, "setup.py", "install"])
152
+ st.write("pc_util installed successfully")
153
+ os.chdir("..")
154
+ st.write(f"Current directory is {os.getcwd()}")
155
+ else:
156
+ st.write(f"Directory {pc_util_path} does not exist")
157
+
158
+ setup_cuda_environment()
159
+
160
+ def setup_cuda_environment():
161
+ cuda_home = '/usr/local/cuda/'
162
+ if not os.path.exists(cuda_home):
163
+ raise EnvironmentError(f"CUDA_HOME directory {cuda_home} does not exist. Please install CUDA and set CUDA_HOME environment variable.")
164
+ os.environ['CUDA_HOME'] = cuda_home
165
+ os.environ['PATH'] = f"{cuda_home}/bin:{os.environ['PATH']}"
166
+ os.environ['LD_LIBRARY_PATH'] = f"{cuda_home}/lib64:{os.environ.get('LD_LIBRARY_PATH', '')}"
167
+ print(f"CUDA env setup: {cuda_home}")
168
+
169
+ from pathlib import Path
170
+ def save_submission(submission, path):
171
+ """
172
+ Saves the submission to a specified path.
173
+
174
+ Parameters:
175
+ submission (List[Dict[]]): The submission to save.
176
+ path (str): The path to save the submission to.
177
+ """
178
+ sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
179
+ sub.to_parquet(path)
180
+ print(f"Submission saved to {path}")
181
+
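As a quick, illustrative sanity check (not part of app.py; the path is an assumption), the saved file can be reloaded to confirm it has the expected schema:

```python
import pandas as pd

sub = pd.read_parquet("submission.parquet")  # adjust to Path(params['output_path']) / "submission.parquet"
assert list(sub.columns) == ["__key__", "wf_vertices", "wf_edges"]
print(sub.head())
```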
182
+ def main():
183
+ st.title("Hugging Face Space Prediction App")
184
+
185
+ # Setting up environment
186
+ st.write("Setting up the environment...")
187
+ # setup_environment()
188
+ try:
189
+ setup_environment()
190
+ except Exception as e:
191
+ st.error(f"Env Setup failed: {e}")
192
+ return
193
+
194
+ usr_local_contents = os.listdir('/usr/local')
195
+ # print("Items under /usr/local:")
196
+ for item in usr_local_contents:
197
+ st.write(item)
198
+
199
+ # Print CUDA path
200
+ cuda_home = os.environ.get('CUDA_HOME', 'CUDA_HOME is not set')
201
+ st.write(f"CUDA_HOME: {cuda_home}")
202
+ st.write(f"PATH: {os.environ.get('PATH', 'PATH is not set')}")
203
+ st.write(f"LD_LIBRARY_PATH: {os.environ.get('LD_LIBRARY_PATH', 'LD_LIBRARY_PATH is not set')}")
204
+
205
+ # export PATH=$PATH:/usr/local/cuda/bin
206
+ # export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
207
+ # export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64
208
+
209
+ from handcrafted_solution import predict
210
+ st.write("Loading dataset...")
211
+
212
+ params = hoho.get_params()
213
+ dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
214
+
215
+ st.write('Running predictions...')
216
+ solution = []
217
+ from concurrent.futures import ProcessPoolExecutor
218
+ with ProcessPoolExecutor(max_workers=8) as pool:
219
+ results = []
220
+ for i, sample in enumerate(tqdm(dataset)):
221
+ results.append(pool.submit(predict, sample, visualize=False))
222
+
223
+ for i, result in enumerate(tqdm(results)):
224
+ key, pred_vertices, pred_edges = result.result()
225
+ solution.append({
226
+ '__key__': key,
227
+ 'wf_vertices': pred_vertices.tolist(),
228
+ 'wf_edges': pred_edges
229
+ })
230
+ if i % 100 == 0:
231
+ # TODO: incrementally save the results here in case we run out of time (currently this only logs progress)
232
+ st.write(f"Processed {i} samples")
233
+
234
+ st.write('Saving results...')
235
+ save_submission(solution, Path(params['output_path']) / "submission.parquet")
236
+ st.write("Done!")
237
+
238
+ if __name__ == "__main__":
239
+ main()
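The prediction loop in main() currently only logs progress every 100 samples. The TODO about incremental saving could be realized with a small helper like the sketch below (illustrative, not the committed code), so that a timeout still leaves a valid submission.parquet on disk:

```python
from pathlib import Path
import pandas as pd

def save_partial(solution, out_dir, i, every=100):
    """Write the predictions accumulated so far every `every` processed samples."""
    if solution and i % every == 0:
        sub = pd.DataFrame(solution, columns=["__key__", "wf_vertices", "wf_edges"])
        sub.to_parquet(Path(out_dir) / "submission.parquet")
```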
data_hoho_pc2wf.py ADDED
@@ -0,0 +1,204 @@
1
+ import webdataset as wds
2
+ import numpy as np
3
+ import hoho
4
+ import open3d as o3d
5
+ import copy
6
+ import trimesh
7
+ from hoho import *
8
+ from huggingface_hub import hf_hub_download
9
+ from hoho import proc
10
+ from tqdm import tqdm
11
+ import os, sys  # os is needed for os.makedirs below
12
+ sys.path.append('..')
13
+ from handcrafted_solution import *
14
+
15
+ """
16
+ dict_keys(['__key__', '__imagekey__', '__url__', 'ade20k', 'depthcm', 'gestalt',
17
+ 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces',
18
+ 'face_semantics', 'K', 'R', 't', 'images', 'points3d', 'cameras'])
19
+ """
20
+
21
+ def stat_remove_outliers(pcd_data, nb_neighbors=20, std_ratio=2.0):
22
+ """
23
+ Remove outliers from a point cloud data using Statistical Outlier Removal (SOR).
24
+
25
+ Parameters:
26
+ - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
27
+ - nb_neighbors (int): Number of neighbors to analyze for each point.
28
+ - std_ratio (float): Standard deviation multiplier for distance threshold.
29
+
30
+ Returns:
31
+ - np.array: Filtered point cloud data as a Nx3 numpy array.
32
+ """
33
+ # Convert to Open3D Point Cloud format
34
+ pcd = o3d.geometry.PointCloud()
35
+ pcd.points = o3d.utility.Vector3dVector(pcd_data)
36
+
37
+ # Perform Statistical Outlier Removal
38
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=nb_neighbors, std_ratio=std_ratio)
39
+
40
+ # Extract the inlier points
41
+ inlier_cloud = pcd.select_by_index(ind)
42
+
43
+ # Convert inlier point cloud back to numpy array
44
+ inlier_pcd_data = np.asarray(inlier_cloud.points)
45
+
46
+ return inlier_pcd_data, pcd, inlier_cloud
47
+
48
+ def remove_z_outliers(pcd_data, low_threshold_percentage=50, high_threshold_percentage=0):
49
+ """
50
+ Remove outliers from a point cloud data based on z-value.
51
+
52
+ Parameters:
53
+ - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
54
+ - low_threshold_percentage (float): Percentage of points to be removed based on the lowest z-values.
55
+ - high_threshold_percentage (float): Percentage of points to be removed based on the highest z-values.
56
+
57
+ Returns:
58
+ - np.array: Filtered point cloud data as a Nx3 numpy array.
59
+ """
60
+ num_std=3
61
+ low_z_threshold = np.percentile(pcd_data[:, 2], low_threshold_percentage)
62
+ high_z_threshold = np.percentile(pcd_data[:, 2], 100 - high_threshold_percentage)
63
+ mean_z, std_z = np.mean(pcd_data[:, 2]), np.std(pcd_data[:, 2])
64
+ z_range = (mean_z - num_std * std_z, mean_z + num_std * std_z)
65
+
66
+ # filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold) & (pcd_data[:, 2] < z_range[1])]
67
+ filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold)]
68
+
69
+ return filtered_pcd_data
70
+
71
+ def remove_xy_outliers(pcd_data, num_std=2):
72
+ """
73
+ Remove outliers from a point cloud data based on x and y values using a Gaussian distribution.
74
+
75
+ Parameters:
76
+ - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
77
+ - num_std (float): Number of standard deviations from the mean to define the acceptable range.
78
+
79
+ Returns:
80
+ - np.array: Filtered point cloud data as a Nx3 numpy array.
81
+ """
82
+ mean_x, std_x = np.mean(pcd_data[:, 0]), np.std(pcd_data[:, 0])
83
+ mean_y, std_y = np.mean(pcd_data[:, 1]), np.std(pcd_data[:, 1])
84
+
85
+ x_range = (mean_x - num_std * std_x, mean_x + num_std * std_x)
86
+ y_range = (mean_y - num_std * std_y, mean_y + num_std * std_y)
87
+
88
+ filtered_pcd_data = pcd_data[(pcd_data[:, 0] >= x_range[0]) & (pcd_data[:, 0] <= x_range[1]) &
89
+ (pcd_data[:, 1] >= y_range[0]) & (pcd_data[:, 1] <= y_range[1])]
90
+
91
+ return filtered_pcd_data
92
+
93
+ def visualize_o3d_pcd(original_pcd, filtered_pcd):
94
+ """
95
+ Visualize the original and filtered point cloud data.
96
+
97
+ Parameters:
98
+ - original_pcd (o3d.geometry.PointCloud): The original point cloud data.
99
+ - filtered_pcd (o3d.geometry.PointCloud): The filtered point cloud data.
100
+ """
101
+ original_pcd.paint_uniform_color([1, 0, 0]) # Red color
102
+
103
+ filtered_pcd.paint_uniform_color([0, 1, 0]) # Green color
104
+
105
+ # Create a visualization window
106
+ vis = o3d.visualization.Visualizer()
107
+ vis.create_window()
108
+
109
+ vis.add_geometry(original_pcd)
110
+ vis.add_geometry(filtered_pcd)
111
+
112
+ vis.run()
113
+ vis.destroy_window()
114
+
115
+ def visualize_pcd(original_pcd_data, filtered_pcd_data):
116
+ """
117
+ Visualize the original and filtered point cloud data.
118
+
119
+ Parameters:
120
+ - original_pcd_data (np.array): The original point cloud data.
121
+ - filtered_pcd_data (np.array): The filtered point cloud data.
122
+ """
123
+ # Convert the original and filtered point cloud data to Open3D Point Cloud format
124
+ original_pcd = o3d.geometry.PointCloud()
125
+ original_pcd.points = o3d.utility.Vector3dVector(original_pcd_data)
126
+
127
+ filtered_pcd = o3d.geometry.PointCloud()
128
+ filtered_pcd.points = o3d.utility.Vector3dVector(filtered_pcd_data)
129
+
130
+ original_pcd.paint_uniform_color([1, 0, 0]) # Red color
131
+
132
+ filtered_pcd.paint_uniform_color([0, 1, 0]) # Green color
133
+
134
+ vis = o3d.visualization.Visualizer()
135
+ vis.create_window()
136
+
137
+ vis.add_geometry(original_pcd)
138
+ vis.add_geometry(filtered_pcd)
139
+
140
+ vis.run()
141
+ vis.destroy_window()
142
+
143
+ def write_wf_obj(V, E, filename='wf.obj'):
144
+ with open(filename, 'w') as f:
145
+ # Write vertices
146
+ for vertex in V:
147
+ f.write(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n")
148
+
149
+ # Write edges
150
+ for edge in E:
151
+ # OBJ format is 1-indexed, Python arrays are 0-indexed, so add 1
152
+ f.write(f"l {edge[0] + 1} {edge[1] + 1}\n")
153
+
154
+ def write_pcd_xyz(xyz, filename='pcd.xyz'):
155
+ np.savetxt(filename, xyz, fmt='%.4f')
156
+
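write_wf_obj stores the wireframe as `v x y z` vertex lines and 1-indexed `l i j` edge lines. A minimal illustrative reader for that format (a hypothetical helper, not part of this file) could look like:

```python
import numpy as np

def read_wf_obj(filename):
    """Parse the wireframe OBJ written by write_wf_obj back into (V, E)."""
    V, E = [], []
    with open(filename) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == 'v':
                V.append([float(x) for x in parts[1:4]])
            elif parts[0] == 'l':
                E.append((int(parts[1]) - 1, int(parts[2]) - 1))  # OBJ indices are 1-based
    return np.array(V), E
```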
157
+ # One shard of the dataset 000-024
158
+ # scene_id = '000'
159
+
160
+ for i in range(0, 25):
161
+ scene_id = str(i).zfill(3)
162
+ print("Processing the scene: ", scene_id)
163
+
164
+ dataset = wds.WebDataset(hf_hub_download(repo_id='usm3d/hoho-train-set',
165
+ filename=f'data/train/hoho_v3_{scene_id}-of-032.tar.gz',
166
+ repo_type="dataset"))
167
+
168
+ # data_dir = Path('./data/')
169
+ # data_dir.mkdir(exist_ok=True)
170
+ # split = 'all'
171
+ # hoho.LOCAL_DATADIR = hoho.setup(data_dir)
172
+
173
+ dataset = dataset.decode()
174
+ dataset = dataset.map(proc)
175
+
176
+ os.makedirs('xyz', exist_ok=True)
177
+ os.makedirs('clean_xyz', exist_ok=True)
178
+ os.makedirs('gt', exist_ok=True)
179
+
180
+ for entry in tqdm(dataset, desc="Processing entries"):
181
+ human_entry = convert_entry_to_human_readable(entry)
182
+ key = human_entry['__key__']
183
+ cameras, images, points3D = human_entry['cameras'], human_entry['images'], human_entry['points3d']
184
+ xyz = np.stack([p.xyz for p in points3D.values()])
185
+ V, E = human_entry['wf_vertices'], human_entry['wf_edges']
186
+ u = trimesh.Trimesh(vertices=human_entry['mesh_vertices'] , faces=human_entry['mesh_faces'][:, 1:])
187
+
188
+ points, _ = trimesh.sample.sample_surface_even(u, count=10000)
189
+
190
+ # print(xyz.shape)
191
+ # print(V.shape)
192
+ # print(E.shape)
193
+ # filtered_pcd_data, original_pcd, filtered_pcd = stat_remove_outliers(xyz)
194
+ # filtered_pcd_data = remove_low_z_outliers(xyz)
195
+ filtered_pcd_data = remove_z_outliers(points, low_threshold_percentage=30, high_threshold_percentage=1.0)
196
+ # filtered_pcd_data = remove_xy_outliers(filtered_pcd_data, num_std=2)
197
+ # visualize_pcd(points, filtered_pcd_data)
198
+ # write_wf_obj(V, E, f'gt/{key}.obj')
199
+ # write_pcd_xyz(xyz, f'xyz/{key}.xyz')
200
+ # write_pcd_xyz(filtered_pcd_data, f'clean_xyz/{key}.xyz')
201
+ # print (key)
202
+ # print (entry.keys())
203
+ # break
204
+
example_on_training.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
feature_solution.py ADDED
@@ -0,0 +1,687 @@
1
+ # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
2
+
3
+ import io
4
+ from PIL import Image as PImage
5
+ import numpy as np
6
+ from collections import defaultdict
7
+ import cv2
8
+ import open3d as o3d
9
+ from typing import Tuple, List
10
+ from scipy.spatial.distance import cdist
11
+
12
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
13
+ from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
14
+ import matplotlib.pyplot as plt
15
+
16
+ from kornia.feature import LoFTR
17
+ import kornia as K
18
+ import kornia.feature as KF
19
+
20
+ import torch
21
+
22
+ import copy
23
+
24
+ import matplotlib
25
+ import matplotlib.colors as mcolors
26
+ import matplotlib.pyplot as plt
27
+ import numpy as np
28
+
29
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
30
+ """Plot a set of images horizontally.
31
+ Args:
32
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
33
+ titles: a list of strings, as titles for each image.
34
+ cmaps: colormaps for monochrome images.
35
+ """
36
+ n = len(imgs)
37
+ if not isinstance(cmaps, (list, tuple)):
38
+ cmaps = [cmaps] * n
39
+ figsize = (size * n, size * 3 / 4) if size is not None else None
40
+ fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
41
+ if n == 1:
42
+ ax = [ax]
43
+ for i in range(n):
44
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
45
+ ax[i].get_yaxis().set_ticks([])
46
+ ax[i].get_xaxis().set_ticks([])
47
+ ax[i].set_axis_off()
48
+ for spine in ax[i].spines.values(): # remove frame
49
+ spine.set_visible(False)
50
+ if titles:
51
+ ax[i].set_title(titles[i])
52
+ fig.tight_layout(pad=pad)
53
+
54
+ def plot_lines(lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)):
55
+ """Plot lines and endpoints for existing images.
56
+ Args:
57
+ lines: list of ndarrays of size (N, 2, 2).
58
+ colors: string, or list of list of tuples (one for each keypoints).
59
+ ps: size of the keypoints as float pixels.
60
+ lw: line width as float pixels.
61
+ indices: indices of the images to draw the matches on.
62
+ """
63
+ if not isinstance(line_colors, list):
64
+ line_colors = [line_colors] * len(lines)
65
+ if not isinstance(point_colors, list):
66
+ point_colors = [point_colors] * len(lines)
67
+
68
+ fig = plt.gcf()
69
+ ax = fig.axes
70
+ assert len(ax) > max(indices)
71
+ axes = [ax[i] for i in indices]
72
+ fig.canvas.draw()
73
+
74
+ # Plot the lines and junctions
75
+ for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
76
+ for i in range(len(l)):
77
+ line = matplotlib.lines.Line2D(
78
+ (l[i, 1, 1], l[i, 0, 1]),
79
+ (l[i, 1, 0], l[i, 0, 0]),
80
+ zorder=1,
81
+ c=lc,
82
+ linewidth=lw,
83
+ )
84
+ a.add_line(line)
85
+ pts = l.reshape(-1, 2)
86
+ a.scatter(pts[:, 1], pts[:, 0], c=pc, s=ps, linewidths=0, zorder=2)
87
+
88
+ def plot_color_line_matches(lines, lw=2, indices=(0, 1)):
89
+ """Plot line matches for existing images with multiple colors.
90
+ Args:
91
+ lines: list of ndarrays of size (N, 2, 2).
92
+ lw: line width as float pixels.
93
+ indices: indices of the images to draw the matches on.
94
+ """
95
+ n_lines = len(lines[0])
96
+
97
+ cmap = plt.get_cmap("nipy_spectral", lut=n_lines)
98
+ colors = np.array([mcolors.rgb2hex(cmap(i)) for i in range(cmap.N)])
99
+
100
+ np.random.shuffle(colors)
101
+
102
+ fig = plt.gcf()
103
+ ax = fig.axes
104
+ assert len(ax) > max(indices)
105
+ axes = [ax[i] for i in indices]
106
+ fig.canvas.draw()
107
+
108
+ # Plot the lines
109
+ for a, l in zip(axes, lines):
110
+ for i in range(len(l)):
111
+ line = matplotlib.lines.Line2D(
112
+ (l[i, 1, 1], l[i, 0, 1]),
113
+ (l[i, 1, 0], l[i, 0, 0]),
114
+ zorder=1,
115
+ c=colors[i],
116
+ linewidth=lw,
117
+ )
118
+ a.add_line(line)
119
+
120
+ def empty_solution():
121
+ '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
122
+ return np.zeros((2,3)), [(0, 1)]
123
+
124
+ def convert_entry_to_human_readable(entry):
125
+ out = {}
126
+ already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
127
+ for k, v in entry.items():
128
+ if k in already_good:
129
+ out[k] = v
130
+ continue
131
+ if k == 'points3d':
132
+ out[k] = read_points3D_binary(fid=io.BytesIO(v))
133
+ if k == 'cameras':
134
+ out[k] = read_cameras_binary(fid=io.BytesIO(v))
135
+ if k == 'images':
136
+ out[k] = read_images_binary(fid=io.BytesIO(v))
137
+ if k in ['ade20k', 'gestalt']:
138
+ out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
139
+ if k == 'depthcm':
140
+ out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
141
+ return out
142
+
143
+
144
+ def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 50.0):
145
+ '''Get the vertices and edges from the gestalt segmentation mask of the house'''
146
+ vertices = []
147
+ connections = []
148
+ # Apex
149
+ apex_color = np.array(gestalt_color_mapping['apex'])
150
+ apex_mask = cv2.inRange(gest_seg_np, apex_color-0.5, apex_color+0.5)
151
+ if apex_mask.sum() > 0:
152
+ output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
153
+ (numLabels, labels, stats, centroids) = output
154
+ stats, centroids = stats[1:], centroids[1:]
155
+
156
+ for i in range(numLabels-1):
157
+ vert = {"xy": centroids[i], "type": "apex"}
158
+ vertices.append(vert)
159
+
160
+ eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
161
+ eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-0.5, eave_end_color+0.5)
162
+ if eave_end_mask.sum() > 0:
163
+ output = cv2.connectedComponentsWithStats(eave_end_mask, 8, cv2.CV_32S)
164
+ (numLabels, labels, stats, centroids) = output
165
+ stats, centroids = stats[1:], centroids[1:]
166
+
167
+ for i in range(numLabels-1):
168
+ vert = {"xy": centroids[i], "type": "eave_end_point"}
169
+ vertices.append(vert)
170
+ # Connectivity
171
+ apex_pts = []
172
+ apex_pts_idxs = []
173
+ for j, v in enumerate(vertices):
174
+ apex_pts.append(v['xy'])
175
+ apex_pts_idxs.append(j)
176
+ apex_pts = np.array(apex_pts)
177
+
178
+ # Ridge connects two apex points
179
+ for edge_class in ['eave', 'ridge', 'rake', 'valley']:
180
+ edge_color = np.array(gestalt_color_mapping[edge_class])
181
+ mask = cv2.morphologyEx(cv2.inRange(gest_seg_np,
182
+ edge_color-0.5,
183
+ edge_color+0.5),
184
+ cv2.MORPH_DILATE, np.ones((11, 11)))
185
+ line_img = np.copy(gest_seg_np) * 0
186
+ if mask.sum() > 0:
187
+ output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
188
+ (numLabels, labels, stats, centroids) = output
189
+ stats, centroids = stats[1:], centroids[1:]
190
+ edges = []
191
+ for i in range(1, numLabels):
192
+ y,x = np.where(labels == i)
193
+ xleft_idx = np.argmin(x)
194
+ x_left = x[xleft_idx]
195
+ y_left = y[xleft_idx]
196
+ xright_idx = np.argmax(x)
197
+ x_right = x[xright_idx]
198
+ y_right = y[xright_idx]
199
+ edges.append((x_left, y_left, x_right, y_right))
200
+ cv2.line(line_img, (x_left, y_left), (x_right, y_right), (255, 255, 255), 2)
201
+ edges = np.array(edges)
202
+ if (len(apex_pts) < 2) or len(edges) <1:
203
+ continue
204
+ pts_to_edges_dist = np.minimum(cdist(apex_pts, edges[:,:2]), cdist(apex_pts, edges[:,2:]))
205
+ connectivity_mask = pts_to_edges_dist <= edge_th
206
+ edge_connects = connectivity_mask.sum(axis=0)
207
+ for edge_idx, edgesum in enumerate(edge_connects):
208
+ if edgesum>=2:
209
+ connected_verts = np.where(connectivity_mask[:,edge_idx])[0]
210
+ for a_i, a in enumerate(connected_verts):
211
+ for b in connected_verts[a_i+1:]:
212
+ connections.append((a, b))
213
+ return vertices, connections
214
+
215
+ def get_uv_depth(vertices, depth):
216
+ '''Get the depth of the vertices from the depth image'''
217
+ uv = []
218
+ for v in vertices:
219
+ uv.append(v['xy'])
220
+ uv = np.array(uv)
221
+ uv_int = uv.astype(np.int32)
222
+ H, W = depth.shape[:2]
223
+ uv_int[:, 0] = np.clip( uv_int[:, 0], 0, W-1)
224
+ uv_int[:, 1] = np.clip( uv_int[:, 1], 0, H-1)
225
+ vertex_depth = depth[(uv_int[:, 1] , uv_int[:, 0])]
226
+ return uv, vertex_depth
227
+
228
+ from scipy.spatial import distance_matrix
229
+ def non_maximum_suppression(points, threshold):
230
+ if len(points) == 0:
231
+ return points, []  # keep the (points, indices) return shape even when empty
232
+
233
+ # Create a distance matrix
234
+ dist_matrix = distance_matrix(points, points)
235
+
236
+ filtered_indices = []
237
+
238
+ # Suppress points within the threshold
239
+ keep = np.ones(len(points), dtype=bool)
240
+ for i in range(len(points)):
241
+ if keep[i]:
242
+ # Suppress points that are close to the current point
243
+ keep = np.logical_and(keep, dist_matrix[i] > threshold)
244
+ keep[i] = True # Keep the current point itself
245
+ filtered_indices.append(i)
246
+ return points[keep], filtered_indices
247
+
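A tiny usage example of non_maximum_suppression (illustrative values; assumes feature_solution.py is importable): points closer than the threshold are suppressed by the first point kept.

```python
import numpy as np
from feature_solution import non_maximum_suppression

pts = np.array([[0.0, 0.0], [10.0, 0.0], [12.0, 0.0], [100.0, 0.0]])
kept, kept_idx = non_maximum_suppression(pts, threshold=50)
print(kept, kept_idx)  # expected: [[0, 0], [100, 0]] and indices [0, 3]
```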
248
+ def merge_vertices_3d_ours(vert_edge_per_image, th=0.1):
249
+ '''Merge vertices that are close to each other in 3D space and are of same types'''
250
+ all_3d_vertices = []
251
+ connections_3d = []
252
+ all_indexes = []
253
+ cur_start = 0
254
+ types = []
255
+ for cimg_idx, (connections, vertices_3d) in vert_edge_per_image.items():
256
+ all_3d_vertices.append(vertices_3d)
257
+ connections = []
258
+ # connections_3d+=[(x+cur_start,y+cur_start) for (x,y) in connections]
259
+ # cur_start+=len(vertices_3d)
260
+ all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
261
+ new_vertices, _ = non_maximum_suppression(all_3d_vertices, 75)
262
+ new_connections = []
263
+ return new_vertices, new_connections
264
+
265
+ def merge_vertices_3d(vert_edge_per_image, th=0.1):
266
+ '''Merge vertices that are close to each other in 3D space and are of same types'''
267
+ all_3d_vertices = []
268
+ connections_3d = []
269
+ all_indexes = []
270
+ cur_start = 0
271
+ types = []
272
+ for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
273
+ types += [int(v['type']=='apex') for v in vertices]
274
+ all_3d_vertices.append(vertices_3d)
275
+ connections_3d+=[(x+cur_start,y+cur_start) for (x,y) in connections]
276
+ cur_start+=len(vertices_3d)
277
+ all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
278
+ #print (connections_3d)
279
+ distmat = cdist(all_3d_vertices, all_3d_vertices)
280
+ types = np.array(types).reshape(-1,1)
281
+ same_types = cdist(types, types)
282
+ mask_to_merge = (distmat <= th) & (same_types==0)
283
+ new_vertices = []
284
+ new_connections = []
285
+ to_merge = sorted(list(set([tuple(a.nonzero()[0].tolist()) for a in mask_to_merge])))
286
+ to_merge_final = defaultdict(list)
287
+ for i in range(len(all_3d_vertices)):
288
+ for j in to_merge:
289
+ if i in j:
290
+ to_merge_final[i]+=j
291
+ for k, v in to_merge_final.items():
292
+ to_merge_final[k] = list(set(v))
293
+ already_there = set()
294
+ merged = []
295
+ for k, v in to_merge_final.items():
296
+ if k in already_there:
297
+ continue
298
+ merged.append(v)
299
+ for vv in v:
300
+ already_there.add(vv)
301
+ old_idx_to_new = {}
302
+ count=0
303
+ for idxs in merged:
304
+ new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
305
+ for idx in idxs:
306
+ old_idx_to_new[idx] = count
307
+ count +=1
308
+ #print (connections_3d)
309
+ new_vertices=np.array(new_vertices)
310
+ #print (connections_3d)
311
+ for conn in connections_3d:
312
+ new_con = sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]]))
313
+ if new_con[0] == new_con[1]:
314
+ continue
315
+ if new_con not in new_connections:
316
+ new_connections.append(new_con)
317
+ #print (f'{len(new_vertices)} left after merging {len(all_3d_vertices)} with {th=}')
318
+ return new_vertices, new_connections
319
+
320
+ def prune_not_connected(all_3d_vertices, connections_3d):
321
+ '''Prune vertices that are not connected to any other vertex'''
322
+ connected = defaultdict(list)
323
+ for c in connections_3d:
324
+ connected[c[0]].append(c)
325
+ connected[c[1]].append(c)
326
+ new_indexes = {}
327
+ new_verts = []
328
+ connected_out = []
329
+ for k,v in connected.items():
330
+ vert = all_3d_vertices[k]
331
+ if tuple(vert) not in new_verts:
332
+ new_verts.append(tuple(vert))
333
+ new_indexes[k]=len(new_verts) -1
334
+ for k,v in connected.items():
335
+ for vv in v:
336
+ connected_out.append((new_indexes[vv[0]],new_indexes[vv[1]]))
337
+ connected_out=list(set(connected_out))
338
+
339
+ return np.array(new_verts), connected_out
340
+
341
+ def loftr_matcher(gestalt_img_0, gestalt_img1, depth_images):
342
+ import torchvision.transforms as transforms
343
+ rgb_to_gray = transforms.Compose([
344
+ transforms.ToPILImage(), # Convert tensor to PIL image
345
+ transforms.Grayscale(num_output_channels=1), # Convert to grayscale
346
+ transforms.ToTensor() # Convert back to tensor
347
+ ])
348
+
349
+ device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')
350
+
351
+ w, h = depth_images.size
352
+ gest_seg_0 = gestalt_img_0.resize(depth_images.size)
353
+ gest_seg_0 = gest_seg_0.convert('L')
354
+ gest_seg_0_np = np.array(gest_seg_0)
355
+ gest_seg_0_tensor = K.image_to_tensor(gest_seg_0_np, False).float().to(device)
356
+ img1 = K.geometry.resize(gest_seg_0_tensor, (int(h/4), int(w/4))) / 255
357
+
358
+ gest_seg_1 = gestalt_img1.resize(depth_images.size)
359
+ gest_seg_1 = gest_seg_1.convert('L')
360
+ gest_seg_1_np = np.array(gest_seg_1)
361
+ gest_seg_1_tensor = K.image_to_tensor(gest_seg_1_np, False).float().to(device)
362
+ img2 = K.geometry.resize(gest_seg_1_tensor, (int(h/4), int(w/4))) / 255
363
+
364
+ matcher = KF.LoFTR(pretrained="outdoor").to(device)
365
+
366
+ input_dict = {
367
+ "image0": img1,
368
+ "image1": img2,
369
+ }
370
+ # print("Input dict shape", input_dict["image0"].shape, input_dict["image1"].shape)
371
+
372
+ with torch.no_grad():
373
+ correspondences = matcher(input_dict)
374
+
375
+ # mkpts0 = correspondences["keypoints0"].cpu().numpy()
376
+ # mkpts1 = correspondences["keypoints1"].cpu().numpy()
377
+ # Fm, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.99, 0.3, 100000)
378
+ # inliers = inliers > 0
379
+ # inliers_flat = inliers.flatten()
380
+
381
+ mkpts0 = correspondences["keypoints0"].cpu().numpy() * 4
382
+ mkpts1 = correspondences["keypoints1"].cpu().numpy() * 4
383
+
384
+ # drop keypoints in the bottom part of the image, i.e. with y in [0.6*H, H] (w=1920, h=1080)
385
+ height_th = int(0.6 * h)
386
+ filter_indices = mkpts0[:, 1] < height_th
387
+ mkpts0 = mkpts0[filter_indices]
388
+ mkpts1 = mkpts1[filter_indices]
389
+
390
+ return correspondences, mkpts0, mkpts1
391
+
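The commented-out block above hints at geometric verification of the LoFTR matches with a fundamental matrix. A hedged sketch of that step (threshold and confidence values are assumptions, not tuned for this data):

```python
import cv2

def filter_matches_with_F(mkpts0, mkpts1, ransac_th=0.5):
    """Estimate a fundamental matrix with MAGSAC++ and keep only the inlier matches."""
    if len(mkpts0) < 8:
        return mkpts0, mkpts1  # too few correspondences for a stable estimate
    F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, ransac_th, 0.999, 10000)
    if F is None or inliers is None:
        return mkpts0, mkpts1
    inliers = inliers.ravel().astype(bool)
    return mkpts0[inliers], mkpts1[inliers]
```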
392
+ def disk_matcher(gestalt_img_0, gestalt_img1, depth_images):
393
+ import torchvision.transforms as transforms
394
+ rgb_to_gray = transforms.Compose([
395
+ transforms.ToPILImage(), # Convert tensor to PIL image
396
+ transforms.Grayscale(num_output_channels=1), # Convert to grayscale
397
+ transforms.ToTensor() # Convert back to tensor
398
+ ])
399
+
400
+ device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')
401
+
402
+ w, h = depth_images.size
403
+ gest_seg_0 = gestalt_img_0.resize(depth_images.size)
404
+ gest_seg_0 = gest_seg_0.convert('L')
405
+ gest_seg_0_np = np.array(gest_seg_0)
406
+ gest_seg_0_tensor = K.image_to_tensor(gest_seg_0_np, False).float().to(device)
407
+ img1 = K.geometry.resize(gest_seg_0_tensor, (int(h/4), int(w/4))) / 255
408
+
409
+ gest_seg_1 = gestalt_img1.resize(depth_images.size)
410
+ gest_seg_1 = gest_seg_1.convert('L')
411
+ gest_seg_1_np = np.array(gest_seg_1)
412
+ gest_seg_1_tensor = K.image_to_tensor(gest_seg_1_np, False).float().to(device)
413
+ img2 = K.geometry.resize(gest_seg_1_tensor, (int(h/4), int(w/4))) / 255
414
+
415
+ num_features = 8192
416
+ disk = KF.DISK.from_pretrained("depth").to(device)
417
+
418
+ hw1 = torch.tensor(img1.shape[2:], device=device)
419
+ hw2 = torch.tensor(img2.shape[2:], device=device)
420
+
421
+ lg_matcher = KF.LightGlueMatcher("disk").eval().to(device)
422
+
423
+ with torch.no_grad():
424
+ inp = torch.cat([img1, img2], dim=0)
425
+ features1, features2 = disk(inp, num_features, pad_if_not_divisible=True)
426
+ kps1, descs1 = features1.keypoints, features1.descriptors
427
+ kps2, descs2 = features2.keypoints, features2.descriptors
428
+ lafs1 = KF.laf_from_center_scale_ori(kps1[None], torch.ones(1, len(kps1), 1, 1, device=device))
429
+ lafs2 = KF.laf_from_center_scale_ori(kps2[None], torch.ones(1, len(kps2), 1, 1, device=device))
430
+ dists, idxs = lg_matcher(descs1, descs2, lafs1, lafs2, hw1=hw1, hw2=hw2)
431
+ print(f"{idxs.shape[0]} tentative matches with DISK LightGlue")
432
+
433
+ lg = KF.LightGlue("disk").to(device).eval()
434
+
435
+ image0 = {
436
+ "keypoints": features1.keypoints[None],
437
+ "descriptors": features1.descriptors[None],
438
+ "image_size": torch.tensor(img1.shape[-2:][::-1]).view(1, 2).to(device),
439
+ }
440
+ image1 = {
441
+ "keypoints": features2.keypoints[None],
442
+ "descriptors": features2.descriptors[None],
443
+ "image_size": torch.tensor(img2.shape[-2:][::-1]).view(1, 2).to(device),
444
+ }
445
+
446
+ with torch.inference_mode():
447
+ out = lg({"image0": image0, "image1": image1})
448
+ idxs = out["matches"][0]
449
+ print(f"{idxs.shape[0]} tentative matches with DISK LightGlue")
450
+
451
+ def get_matching_keypoints(kp1, kp2, idxs):
452
+ mkpts1 = kp1[idxs[:, 0]]
453
+ mkpts2 = kp2[idxs[:, 1]]
454
+ return mkpts1, mkpts2
455
+
456
+ mkpts0, mkpts1 = get_matching_keypoints(kps1, kps2, idxs)
457
+
458
+ mkpts0*=4
459
+ mkpts1*=4
460
+ return mkpts0, mkpts1
461
+
462
+ def save_image_with_keypoints(filename: str, image: np.ndarray, keypoints: np.ndarray, color: Tuple[int, int, int]) -> np.ndarray:
463
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
464
+ for keypoint in keypoints:
465
+ pt = (int(keypoint[0]), int(keypoint[1]))
466
+ cv2.circle(image, pt, 4, color, -1)
467
+ # save as png
468
+ cv2.imwrite(filename, image)
469
+
470
+ ###### added for lines detection ######
471
+ def save_image_with_lines(filename: str, image: np.ndarray, lines: np.ndarray, color: Tuple[int, int, int]) -> None:
472
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
473
+ for line in lines:
474
+ pt1 = (int(line[0][1]), int(line[0][0]))
475
+ pt2 = (int(line[1][1]), int(line[1][0]))
476
+ cv2.line(image, pt1, pt2, color, 2)
477
+ cv2.imwrite(filename, image)
478
+
479
+ def line_matcher(gestalt_img_0, gestalt_img1, depth_images, line_th=0.1):
480
+ import torchvision.transforms as transforms
481
+ rgb_to_gray = transforms.Compose([
482
+ transforms.ToPILImage(), # Convert tensor to PIL image
483
+ transforms.Grayscale(num_output_channels=1), # Convert to grayscale
484
+ transforms.ToTensor() # Convert back to tensor
485
+ ])
486
+
487
+ device = 'cpu'
488
+
489
+ w, h = depth_images.size
490
+
491
+ gest_seg_0 = gestalt_img_0.resize(depth_images.size)
492
+ gest_seg_0 = gest_seg_0.convert('L')
493
+ gest_seg_0_np = np.array(gest_seg_0)
494
+ gest_seg_0_tensor = K.image_to_tensor(gest_seg_0_np, False).float().to(device)
495
+ img1 = K.geometry.resize(gest_seg_0_tensor, (int(h/4), int(w/4))) / 255
496
+
497
+ gest_seg_1 = gestalt_img1.resize(depth_images.size)
498
+ gest_seg_1 = gest_seg_1.convert('L')
499
+ gest_seg_1_np = np.array(gest_seg_1)
500
+ gest_seg_1_tensor = K.image_to_tensor(gest_seg_1_np, False).float().to(device)
501
+ img2 = K.geometry.resize(gest_seg_1_tensor, (int(h/4), int(w/4))) / 255
502
+
503
+ sold2 = KF.SOLD2(pretrained=True, config=None)
504
+
505
+ imgs = torch.cat([img1, img2], dim=0)
506
+ with torch.inference_mode():
507
+ outputs = sold2(imgs)
508
+ print(outputs.keys())
509
+
510
+ line_seg1 = outputs["line_segments"][0]
511
+ line_seg2 = outputs["line_segments"][1]
512
+ desc1 = outputs["dense_desc"][0]
513
+ desc2 = outputs["dense_desc"][1]
514
+
515
+ # print("Input dict shape", input_dict["image0"].shape, input_dict["image1"].shape)
516
+ with torch.no_grad():
517
+ matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
518
+
519
+ valid_matches = matches != -1
520
+ match_indices = matches[valid_matches]
521
+
522
+ matched_lines1 = line_seg1[valid_matches] * 4
523
+ matched_lines2 = line_seg2[match_indices] * 4
524
+
525
+ # drop line segments with any endpoint in the bottom part of the image, i.e. with y in [0.6*H, H] (w=1920, h=1080)
526
+ height_th = int(0.6 * h)
527
+ # filter_indices = (matched_lines1[:, 0, 1] < height_th).all(1) & (matched_lines1[:, 0, 1] < height_th).all(1)
528
+ filter_indices = (matched_lines1[:, :, 0] < height_th).all(axis=1) & \
529
+ (matched_lines2[:, :, 0] < height_th).all(axis=1)
530
+ matched_lines1 = matched_lines1[filter_indices]
531
+ matched_lines2 = matched_lines2[filter_indices]
532
+
533
+ return matched_lines1, matched_lines2
534
+
535
+ # Gestalt color mapping
536
+ gestalt_color_mapping = {
537
+ 'unclassified': [215, 62, 138],
538
+ 'apex': [235, 88, 48],
539
+ 'eave_end_point': [248, 130, 228],
540
+ 'eave': [54, 243, 63],
541
+ 'ridge': [214, 251, 248],
542
+ 'rake': [13, 94, 47],
543
+ 'valley': [85, 27, 65],
544
+ 'unknown': [127, 127, 127]
545
+ }
546
+
547
+ def extract_segmented_area(image: np.ndarray, color: List[int]) -> np.ndarray:
548
+ lower = np.array(color) - 3 # 0.5
549
+ upper = np.array(color) + 3 # 0.5
550
+ mask = cv2.inRange(image, lower, upper)
551
+ return mask
552
+
553
+ def combine_masks(image: np.ndarray, color_mapping: dict) -> np.ndarray:
554
+ combined_mask = np.zeros(image.shape[:2], dtype=np.uint8)
555
+ for color in color_mapping.values():
556
+ mask = extract_segmented_area(image, color)
557
+ combined_mask = cv2.bitwise_or(combined_mask, mask)
558
+ return combined_mask
559
+
560
+ def filter_points_by_mask(points: np.ndarray, mask: np.ndarray) -> np.ndarray:
561
+ filtered_points = []
562
+ filtered_indices = []
563
+ for idx, point in enumerate(points):
564
+ y, x = int(point[1]), int(point[0])
565
+ if mask[y, x] > 0:
566
+ filtered_points.append(point)
567
+ filtered_indices.append(idx)
568
+ return np.array(filtered_points), filtered_indices
569
+
570
+ ###### added for lines detection ########
571
+
572
+ def triangulate_points(mkpts0, mkpts1, R_0, t_0, R_1, t_1, intrinsics):
573
+ P0 = intrinsics @ np.hstack((R_0, t_0.reshape(-1, 1)))
574
+ P1 = intrinsics @ np.hstack((R_1, t_1.reshape(-1, 1)))
575
+
576
+ mkpts0_h = np.vstack((mkpts0.T, np.ones((1, mkpts0.shape[0]))))
577
+ mkpts1_h = np.vstack((mkpts1.T, np.ones((1, mkpts1.shape[0]))))
578
+
579
+ points_4D_hom = cv2.triangulatePoints(P0, P1, mkpts0_h[:2], mkpts1_h[:2])
580
+ points_3D = points_4D_hom / points_4D_hom[3]
581
+
582
+ return points_3D[:3].T
583
+
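A minimal synthetic check of triangulate_points (illustrative only; it assumes feature_solution.py is importable and uses the same world-to-camera convention for R, t as the rest of this file):

```python
import numpy as np
from feature_solution import triangulate_points

K = np.array([[1000., 0., 960.], [0., 1000., 540.], [0., 0., 1.]])
X_gt = np.array([[1.0, 2.0, 10.0]])             # ground-truth world point
R0, t0 = np.eye(3), np.zeros(3)
R1, t1 = np.eye(3), np.array([-1.0, 0.0, 0.0])  # second camera shifted along x

def project(X, R, t):
    x = (K @ (R @ X.T + t.reshape(3, 1))).T
    return x[:, :2] / x[:, 2:]

uv0, uv1 = project(X_gt, R0, t0), project(X_gt, R1, t1)
X_est = triangulate_points(uv0, uv1, R0, t0, R1, t1, K)
print(np.allclose(X_est, X_gt, atol=1e-3))      # expected: True
```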
584
+ def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
585
+ good_entry = convert_entry_to_human_readable(entry)
586
+ vert_edge_per_image = {}
587
+
588
+ for i, (gest, depth, K, R, t) in enumerate(zip(good_entry['gestalt'],
589
+ good_entry['depthcm'],
590
+ good_entry['K'],
591
+ good_entry['R'],
592
+ good_entry['t']
593
+ )):
594
+ # LoFTR matching keypoints
595
+ if i < 2:
596
+ j = i + 1
597
+ else:
598
+ j = 0
599
+ correspondences, mkpts0, mkpts1 = loftr_matcher(good_entry['gestalt'][i], good_entry['gestalt'][j], good_entry['depthcm'][i])
600
+ # mkpts0, mkpts1 = disk_matcher(good_entry['gestalt'][i], good_entry['gestalt'][j], good_entry['depthcm'][i])
601
+
602
+ # Added by Tang: apply mask to filter out keypoints in mkpts0
603
+ gest_seg_np = np.array(gest.resize(depth.size)).astype(np.uint8)
604
+
605
+ gest_seg_0 = np.array(good_entry['gestalt'][i].resize(depth.size)).astype(np.uint8)
606
+ gest_seg_1 = np.array(good_entry['gestalt'][j].resize(depth.size)).astype(np.uint8)
607
+
608
+ combined_mask_0 = combine_masks(gest_seg_0, gestalt_color_mapping)
609
+ combined_mask_1 = combine_masks(gest_seg_1, gestalt_color_mapping)
610
+
611
+ mkpts_filtered_0, indice_0 = filter_points_by_mask(mkpts0, combined_mask_0)
612
+ mkpts_filtered_1 = mkpts1[indice_0]
613
+
614
+ # Add NMS for 2D keypoints
615
+ mkpts_filtered_0, filtered_index = non_maximum_suppression(mkpts_filtered_0, 50)
616
+ mkpts_filtered_1 = mkpts_filtered_1[filtered_index]
617
+
618
+
619
+ save_image_with_keypoints(f'keypoints_{i}.png', np.array(good_entry['gestalt'][i]), mkpts_filtered_0, (255, 0, 0))
620
+ save_image_with_keypoints(f'keypoints_{j}.png', np.array(good_entry['gestalt'][j]), mkpts_filtered_1, (255, 0, 0))
621
+
622
+ # Line matching
623
+ line_0, line_1 = line_matcher(good_entry['gestalt'][i], good_entry['gestalt'][j], good_entry['depthcm'][i])
624
+ save_image_with_lines(f'line_{i}.png', np.array(good_entry['gestalt'][i]), line_0, (255, 0, 0))
625
+ save_image_with_lines(f'line_{j}.png', np.array(good_entry['gestalt'][j]), line_1, (255, 0, 0))
626
+
627
+
628
+ # Triangulation with matched keypoints
629
+ R_0 = good_entry['R'][i]
630
+ t_0 = good_entry['t'][i]
631
+ R_1 = good_entry['R'][j]
632
+ t_1 = good_entry['t'][j]
633
+ intrinsics = K
634
+
635
+ points_3d = triangulate_points(mkpts_filtered_0, mkpts_filtered_1, R_0, t_0, R_1, t_1, intrinsics)
636
+
637
+ gest_seg = gest.resize(depth.size)
638
+ gest_seg_np = np.array(gest_seg).astype(np.uint8)
639
+ # Metric3D
640
+ depth_np = np.array(depth) / 2.5 # 2.5 is the scale estimation coefficient
641
+ vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 5.)
642
+ if (len(vertices) < 2) or (len(connections) < 1):
643
+ print (f'Not enough vertices or connections in image {i}')
644
+ vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
645
+ # continue
646
+ uv, depth_vert = get_uv_depth(vertices, depth_np)
647
+
648
+ # monodepth
649
+ # r<32 scale = colmap depth / monodepth
650
+ # monodepth /= scale
651
+ # # Assuming monodepth is provided similarly as depth
652
+ # monodepth = ?
653
+ # scale = np.mean(depth_np / monodepth)
654
+ # monodepth /= scale
655
+
656
+ # Normalize the uv to the camera intrinsics
657
+ xy_local = np.ones((len(uv), 3))
658
+ xy_local[:, 0] = (uv[:, 0] - K[0,2]) / K[0,0]
659
+ xy_local[:, 1] = (uv[:, 1] - K[1,2]) / K[1,1]
660
+ # Get the 3D vertices
661
+ vertices_3d_local = depth_vert[...,None] * (xy_local/np.linalg.norm(xy_local, axis=1)[...,None])
662
+ world_to_cam = np.eye(4)
663
+ world_to_cam[:3, :3] = R
664
+ world_to_cam[:3, 3] = t.reshape(-1)
665
+ cam_to_world = np.linalg.inv(world_to_cam)
666
+ vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
667
+ vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
668
+ # vert_edge_per_image[i] = vertices, connections, vertices_3d
669
+
670
+ # ours method
671
+ vert_edge_per_image[i] = connections, points_3d
672
+
673
+
674
+ all_3d_vertices, connections_3d = merge_vertices_3d_ours(vert_edge_per_image, 3.0)
675
+
676
+ all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
677
+ if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
678
+ print (f'Not enough vertices or connections in the 3D vertices')
679
+ return (good_entry['__key__'], *empty_solution())
680
+ if visualize:
681
+ from hoho.viz3d import plot_estimate_and_gt
682
+ plot_estimate_and_gt( all_3d_vertices_clean,
683
+ connections_3d_clean,
684
+ good_entry['wf_vertices'],
685
+ good_entry['wf_edges'])
686
+ return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean
687
+
handcrafted_solution.py ADDED
@@ -0,0 +1,245 @@
1
+ # Description: This file contains the handcrafted solution for the task of wireframe reconstruction
2
+
3
+ import io
4
+ from PIL import Image as PImage
5
+ import numpy as np
6
+ from collections import defaultdict
7
+ import cv2
8
+ from typing import Tuple, List
9
+ from scipy.spatial.distance import cdist
10
+
11
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
12
+ from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
13
+
14
+
15
+ def empty_solution():
16
+ '''Return a minimal valid solution, i.e. 2 vertices and 1 edge.'''
17
+ return np.zeros((2,3)), [(0, 1)]
18
+
19
+
20
+ def convert_entry_to_human_readable(entry):
21
+ out = {}
22
+ already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
23
+ for k, v in entry.items():
24
+ if k in already_good:
25
+ out[k] = v
26
+ continue
27
+ if k == 'points3d':
28
+ out[k] = read_points3D_binary(fid=io.BytesIO(v))
29
+ if k == 'cameras':
30
+ out[k] = read_cameras_binary(fid=io.BytesIO(v))
31
+ if k == 'images':
32
+ out[k] = read_images_binary(fid=io.BytesIO(v))
33
+ if k in ['ade20k', 'gestalt']:
34
+ out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
35
+ if k == 'depthcm':
36
+ out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
37
+ return out
38
+
39
+
40
+ def get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 50.0):
41
+ '''Get the vertices and edges from the gestalt segmentation mask of the house'''
42
+ vertices = []
43
+ connections = []
44
+ # Apex
45
+ apex_color = np.array(gestalt_color_mapping['apex'])
46
+ apex_mask = cv2.inRange(gest_seg_np, apex_color-0.5, apex_color+0.5)
47
+ if apex_mask.sum() > 0:
48
+ output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
49
+ (numLabels, labels, stats, centroids) = output
50
+ stats, centroids = stats[1:], centroids[1:]
51
+
52
+ for i in range(numLabels-1):
53
+ vert = {"xy": centroids[i], "type": "apex"}
54
+ vertices.append(vert)
55
+
56
+ eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
57
+ eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-0.5, eave_end_color+0.5)
58
+ if eave_end_mask.sum() > 0:
59
+ output = cv2.connectedComponentsWithStats(eave_end_mask, 8, cv2.CV_32S)
60
+ (numLabels, labels, stats, centroids) = output
61
+ stats, centroids = stats[1:], centroids[1:]
62
+
63
+ for i in range(numLabels-1):
64
+ vert = {"xy": centroids[i], "type": "eave_end_point"}
65
+ vertices.append(vert)
66
+ # Connectivity
67
+ apex_pts = []
68
+ apex_pts_idxs = []
69
+ for j, v in enumerate(vertices):
70
+ apex_pts.append(v['xy'])
71
+ apex_pts_idxs.append(j)
72
+ apex_pts = np.array(apex_pts)
73
+
74
+ # Ridge connects two apex points
75
+ for edge_class in ['eave', 'ridge', 'rake', 'valley']:
76
+ edge_color = np.array(gestalt_color_mapping[edge_class])
77
+ mask = cv2.morphologyEx(cv2.inRange(gest_seg_np,
78
+ edge_color-0.5,
79
+ edge_color+0.5),
80
+ cv2.MORPH_DILATE, np.ones((11, 11)))
81
+ line_img = np.copy(gest_seg_np) * 0
82
+ if mask.sum() > 0:
83
+ output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
84
+ (numLabels, labels, stats, centroids) = output
85
+ stats, centroids = stats[1:], centroids[1:]
86
+ edges = []
87
+ for i in range(1, numLabels):
88
+ y,x = np.where(labels == i)
89
+ xleft_idx = np.argmin(x)
90
+ x_left = x[xleft_idx]
91
+ y_left = y[xleft_idx]
92
+ xright_idx = np.argmax(x)
93
+ x_right = x[xright_idx]
94
+ y_right = y[xright_idx]
95
+ edges.append((x_left, y_left, x_right, y_right))
96
+ cv2.line(line_img, (x_left, y_left), (x_right, y_right), (255, 255, 255), 2)
97
+ edges = np.array(edges)
98
+ if (len(apex_pts) < 2) or len(edges) <1:
99
+ continue
100
+ pts_to_edges_dist = np.minimum(cdist(apex_pts, edges[:,:2]), cdist(apex_pts, edges[:,2:]))
101
+ connectivity_mask = pts_to_edges_dist <= edge_th
102
+ edge_connects = connectivity_mask.sum(axis=0)
103
+ for edge_idx, edgesum in enumerate(edge_connects):
104
+ if edgesum>=2:
105
+ connected_verts = np.where(connectivity_mask[:,edge_idx])[0]
106
+ for a_i, a in enumerate(connected_verts):
107
+ for b in connected_verts[a_i+1:]:
108
+ connections.append((a, b))
109
+ return vertices, connections
110
+
111
+ def get_uv_depth(vertices, depth):
112
+ '''Get the depth of the vertices from the depth image'''
113
+ uv = []
114
+ for v in vertices:
115
+ uv.append(v['xy'])
116
+ uv = np.array(uv)
117
+ uv_int = uv.astype(np.int32)
118
+ H, W = depth.shape[:2]
119
+ uv_int[:, 0] = np.clip( uv_int[:, 0], 0, W-1)
120
+ uv_int[:, 1] = np.clip( uv_int[:, 1], 0, H-1)
121
+ vertex_depth = depth[(uv_int[:, 1] , uv_int[:, 0])]
122
+ return uv, vertex_depth
123
+
124
+
125
+ def merge_vertices_3d(vert_edge_per_image, th=0.1):
+     '''Merge vertices that are close to each other in 3D space and are of same types'''
+     all_3d_vertices = []
+     connections_3d = []
+     all_indexes = []
+     cur_start = 0
+     types = []
+     for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
+         types += [int(v['type'] == 'apex') for v in vertices]
+         all_3d_vertices.append(vertices_3d)
+         connections_3d += [(x+cur_start, y+cur_start) for (x, y) in connections]
+         cur_start += len(vertices_3d)
+     all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
+     #print (connections_3d)
+     distmat = cdist(all_3d_vertices, all_3d_vertices)
+     types = np.array(types).reshape(-1, 1)
+     same_types = cdist(types, types)
+     mask_to_merge = (distmat <= th) & (same_types == 0)
+     new_vertices = []
+     new_connections = []
+     to_merge = sorted(list(set([tuple(a.nonzero()[0].tolist()) for a in mask_to_merge])))
+     to_merge_final = defaultdict(list)
+     for i in range(len(all_3d_vertices)):
+         for j in to_merge:
+             if i in j:
+                 to_merge_final[i] += j
+     for k, v in to_merge_final.items():
+         to_merge_final[k] = list(set(v))
+     already_there = set()
+     merged = []
+     for k, v in to_merge_final.items():
+         if k in already_there:
+             continue
+         merged.append(v)
+         for vv in v:
+             already_there.add(vv)
+     old_idx_to_new = {}
+     count = 0
+     for idxs in merged:
+         new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
+         for idx in idxs:
+             old_idx_to_new[idx] = count
+         count += 1
+     #print (connections_3d)
+     new_vertices = np.array(new_vertices)
+     #print (connections_3d)
+     for conn in connections_3d:
+         new_con = sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]]))
+         if new_con[0] == new_con[1]:
+             continue
+         if new_con not in new_connections:
+             new_connections.append(new_con)
+     #print (f'{len(new_vertices)} left after merging {len(all_3d_vertices)} with {th=}')
+     return new_vertices, new_connections
+
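A toy illustration of the merge (coordinates are made up; it assumes `merge_vertices_3d` above is in scope): two images each observed the same apex within the threshold, so the two observations collapse into one averaged vertex and the per-image connection indices are remapped accordingly.

```python
import numpy as np

vert_edge_per_image = {
    0: ([{"xy": None, "type": "apex"}, {"xy": None, "type": "eave_end_point"}],  # per-image vertices
        [(0, 1)],                                                                # per-image connections
        np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])),                           # their 3D positions
    1: ([{"xy": None, "type": "apex"}],
        [],
        np.array([[0.5, 0.0, 0.0]])),
}
verts, conns = merge_vertices_3d(vert_edge_per_image, th=1.0)
print(len(verts), conns)   # 2 vertices left (the apexes average to [0.25, 0, 0]), edge [0, 1]
```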
+ def prune_not_connected(all_3d_vertices, connections_3d):
+     '''Prune vertices that are not connected to any other vertex'''
+     connected = defaultdict(list)
+     for c in connections_3d:
+         connected[c[0]].append(c)
+         connected[c[1]].append(c)
+     new_indexes = {}
+     new_verts = []
+     connected_out = []
+     for k, v in connected.items():
+         vert = all_3d_vertices[k]
+         if tuple(vert) not in new_verts:
+             new_verts.append(tuple(vert))
+             new_indexes[k] = len(new_verts) - 1
+     for k, v in connected.items():
+         for vv in v:
+             connected_out.append((new_indexes[vv[0]], new_indexes[vv[1]]))
+     connected_out = list(set(connected_out))
+
+     return np.array(new_verts), connected_out
+
+
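And a matching sketch for the pruning step (again assuming the function above is in scope): a vertex with no incident edge is dropped and the remaining indices are compacted.

```python
import numpy as np

verts = np.array([[0., 0., 0.], [1., 0., 0.], [9., 9., 9.]])   # the last vertex is isolated
print(prune_not_connected(verts, [(0, 1)]))                    # -> 2 vertices kept and the edge (0, 1)
```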
+ def predict(entry, visualize=False) -> Tuple[np.ndarray, List[int]]:
+     good_entry = convert_entry_to_human_readable(entry)
+     vert_edge_per_image = {}
+     for i, (gest, depth, K, R, t) in enumerate(zip(good_entry['gestalt'],
+                                                    good_entry['depthcm'],
+                                                    good_entry['K'],
+                                                    good_entry['R'],
+                                                    good_entry['t']
+                                                    )):
+         gest_seg = gest.resize(depth.size)
+         gest_seg_np = np.array(gest_seg).astype(np.uint8)
+         # Metric3D
+         depth_np = np.array(depth) / 2.5  # 2.5 is the scale estimation coefficient
+         vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th=5.)
+         if (len(vertices) < 2) or (len(connections) < 1):
+             print(f'Not enough vertices or connections in image {i}')
+             vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
+             continue
+         uv, depth_vert = get_uv_depth(vertices, depth_np)
+         # Normalize the uv to the camera intrinsics
+         xy_local = np.ones((len(uv), 3))
+         xy_local[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
+         xy_local[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
+         # Get the 3D vertices
+         vertices_3d_local = depth_vert[..., None] * (xy_local / np.linalg.norm(xy_local, axis=1)[..., None])
+         world_to_cam = np.eye(4)
+         world_to_cam[:3, :3] = R
+         world_to_cam[:3, 3] = t.reshape(-1)
+         cam_to_world = np.linalg.inv(world_to_cam)
+         vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
+         vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
+         vert_edge_per_image[i] = vertices, connections, vertices_3d
+     all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 3.0)
+     all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
+     if (len(all_3d_vertices_clean) < 2) or (len(connections_3d_clean) < 1):
+         print('Not enough vertices or connections in the 3D vertices')
+         return (good_entry['__key__'], *empty_solution())
+     if visualize:
+         from hoho.viz3d import plot_estimate_and_gt
+         plot_estimate_and_gt(all_3d_vertices_clean,
+                              connections_3d_clean,
+                              good_entry['wf_vertices'],
+                              good_entry['wf_edges'])
+     return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean
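The per-image back-projection inside `predict` boils down to: normalize the pixel by the intrinsics, scale the resulting unit ray by the (rescaled) monocular depth, and map the camera-frame point to world coordinates with the inverse of the `[R | t]` pose. A self-contained sketch with toy numbers that mirrors those steps:

```python
import numpy as np

K = np.array([[600., 0., 320.], [0., 600., 240.], [0., 0., 1.]])   # toy intrinsics
R, t = np.eye(3), np.array([0.1, 0.0, 0.0])                        # toy world-to-camera pose

uv = np.array([[400., 300.]])   # one detected vertex, in pixels
depth = np.array([250.0])       # its (rescaled) monocular depth

# Pixel -> normalized camera coordinates -> point along the unit ray.
xy = np.ones((len(uv), 3))
xy[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
xy[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
pts_cam = depth[:, None] * xy / np.linalg.norm(xy, axis=1, keepdims=True)

# Camera frame -> world frame by inverting the 4x4 world-to-camera transform.
world_to_cam = np.eye(4)
world_to_cam[:3, :3], world_to_cam[:3, 3] = R, t
cam_to_world = np.linalg.inv(world_to_cam)
pts_world = (cam_to_world @ np.c_[pts_cam, np.ones(len(pts_cam))].T).T[:, :3]
print(pts_world)
```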
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ webdataset
+ opencv-python
+ torchvision
+ pycolmap
+ torch
+ kornia>=0.7.1
+ matplotlib
+ Pillow
+ scipy
+ plotly
+ timm
+ open3d
+ plyfile
+ shapely
+ scikit-spatial
+ scikit-learn
+ numpy
+ git+https://hf.co/usm3d/tools.git
+ trimesh
+ ninja
+ transformers
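These requirements are unpinned except for `kornia>=0.7.1`; for a local run, a plain `pip install -r requirements.txt` should be enough to reproduce the environment. The `git+https://hf.co/usm3d/tools.git` entry presumably provides the `hoho` package that `script.py` imports (an assumption, not something stated in the repo).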
script.py ADDED
@@ -0,0 +1,223 @@
+ ### This is an example of the script that will be run in the test environment.
+ ### Some parts of the code are compulsory and you should NOT CHANGE THEM.
+ ### They are between '''---compulsory---''' comments.
+ ### You can change the rest of the code to define and test your solution.
+ ### However, you should not change the signature of the provided function.
+ ### The script will save the "submission.parquet" file in the current directory.
+ ### The actual logic of the solution is implemented in the `handcrafted_solution.py` file.
+ ### The `handcrafted_solution.py` file is a placeholder for your solution.
+ ### You should implement the logic of your solution in that file.
+ ### You can use any additional files and subdirectories to organize your code.
+
+ '''---compulsory---'''
+ # import subprocess
+ # from pathlib import Path
+ # def install_package_from_local_file(package_name, folder='packages'):
+ #     """
+ #     Installs a package from a local .whl file or a directory containing .whl files using pip.
+
+ #     Parameters:
+ #     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
+ #     """
+ #     try:
+ #         pth = str(Path(folder) / package_name)
+ #         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
+ #                                "--no-index",  # Do not use package index
+ #                                "--find-links", pth,  # Look for packages in the specified directory or at the file
+ #                                package_name])  # Specify the package to install
+ #         print(f"Package installed successfully from {pth}")
+ #     except subprocess.CalledProcessError as e:
+ #         print(f"Failed to install package from {pth}. Error: {e}")
+
+ # install_package_from_local_file('hoho')
+
+ import hoho; hoho.setup()  # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
+ # import subprocess
+ # import importlib
+ # from pathlib import Path
+ # import subprocess
+
+
+ # ### The function below is useful for installing additional python wheels.
+ # def install_package_from_local_file(package_name, folder='packages'):
+ #     """
+ #     Installs a package from a local .whl file or a directory containing .whl files using pip.
+
+ #     Parameters:
+ #     path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
+ #     """
+ #     try:
+ #         pth = str(Path(folder) / package_name)
+ #         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
+ #                                "--no-index",  # Do not use package index
+ #                                "--find-links", pth,  # Look for packages in the specified directory or at the file
+ #                                package_name])  # Specify the package to install
+ #         print(f"Package installed successfully from {pth}")
+ #     except subprocess.CalledProcessError as e:
+ #         print(f"Failed to install package from {pth}. Error: {e}")
+
+
+ # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
+ # install_package_from_local_file('webdataset')
+ # install_package_from_local_file('tqdm')
+
+ ### Here you can import any library or module you want.
+ ### The code below is used to read and parse the input dataset.
+ ### Please, do not modify it.
+
+ import webdataset as wds
+ from tqdm import tqdm
+ from typing import Dict
+ import pandas as pd
+ from transformers import AutoTokenizer
+ import os
+ import time
+ import io
+ from PIL import Image as PImage
+ import numpy as np
+
+ from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
+ from hoho import proc, Sample
+
+ def convert_entry_to_human_readable(entry):
+     out = {}
+     already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
+     for k, v in entry.items():
+         if k in already_good:
+             out[k] = v
+             continue
+         if k == 'points3d':
+             out[k] = read_points3D_binary(fid=io.BytesIO(v))
+         if k == 'cameras':
+             out[k] = read_cameras_binary(fid=io.BytesIO(v))
+         if k == 'images':
+             out[k] = read_images_binary(fid=io.BytesIO(v))
+         if k in ['ade20k', 'gestalt']:
+             out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
+         if k == 'depthcm':
+             out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
+     return out
+
+ '''---end of compulsory---'''
+
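For local experimentation, the compulsory parser above is typically applied to a raw sample pulled from the dataset; a quick sketch (it assumes `hoho.setup()` has already been called, as at the top of this file):

```python
# Peek at one decoded training sample, using the same dataset call as in __main__ below.
dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
sample = next(iter(dataset))
entry = convert_entry_to_human_readable(sample)
print(entry['__key__'], sorted(entry.keys()))
```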
+ ### The part below is used to define and test your solution.
+ import subprocess
+ import sys
+ import os
+
+ import numpy as np
+ os.environ['MKL_THREADING_LAYER'] = 'GNU'
+ os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
+
+ def uninstall_package(package_name):
+     """
+     Uninstalls a package using pip.
+
+     Parameters:
+     package_name (str): The name of the package to uninstall.
+     """
+     try:
+         subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package_name])
+         print(f"Package {package_name} uninstalled successfully")
+     except subprocess.CalledProcessError as e:
+         print(f"Failed to uninstall package {package_name}. Error: {e}")
+
+ def download_packages(packages, folder):
+     # Create the directory if it doesn't exist
+     if not os.path.exists(folder):
+         os.makedirs(folder)
+
+     try:
+         subprocess.check_call([
+             'pip', 'download',
+             '--dest', folder,
+             '-f', 'https://download.pytorch.org/whl/cu121'
+         ] + packages)
+         print(f"Packages downloaded successfully to {folder}")
+     except subprocess.CalledProcessError as e:
+         print(f"Failed to download packages. Error: {e}")
+
+ def install_package_from_local_file(package_name, folder='packages'):
+     """
+     Installs a package from a local .whl file or a directory containing .whl files using pip.
+
+     Parameters:
+     package_name (str): The name of the package to install.
+     folder (str): The folder where the .whl files are located.
+     """
+     try:
+         pth = str(Path(folder) / package_name)
+         subprocess.check_call([sys.executable, "-m", "pip", "install",
+                                "--no-index",  # Do not use package index
+                                "--find-links", pth,  # Look for packages in the specified directory or at the file
+                                package_name])  # Specify the package to install
+         print(f"Package installed successfully from {pth}")
+     except subprocess.CalledProcessError as e:
+         print(f"Failed to install package from {pth}. Error: {e}")
+
+ def install_which():
+     try:
+         # Attempt to install which if it's not available
+         subprocess.check_call(['sudo', 'apt-get', 'install', '-y', 'which'])
+         print("Which installed successfully.")
+     except subprocess.CalledProcessError as e:
+         print(f"An error occurred while installing which: {e}")
+         sys.exit(1)
+
+ def setup_environment():
+     pc_util_path = os.path.join(os.getcwd(), 'pc_util')
+     if os.path.isdir(pc_util_path):
+         os.chdir(pc_util_path)
+         subprocess.check_call([sys.executable, "setup.py", "install"], cwd=pc_util_path)
+         os.chdir("..")
+
+ from pathlib import Path
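The wheel helpers above are meant to be used in two phases, presumably because the test environment installs offline (note the `--no-index` flag): fetch wheels into `packages/` where internet is available, then install them from that folder at run time. A hypothetical pairing (the package names are placeholders; nothing in `script.py` actually calls these):

```python
download_packages(['webdataset', 'tqdm'], folder='packages')        # run where internet is available
install_package_from_local_file('webdataset', folder='packages')    # run offline in the test environment
```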
+ def save_submission(submission, path):
+     """
+     Saves the submission to a specified path.
+
+     Parameters:
+     submission (List[Dict[]]): The submission to save.
+     path (str): The path to save the submission to.
+     """
+     sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
+     sub.to_parquet(path)
+     print(f"Submission saved to {path}")
+
+ if __name__ == "__main__":
+     from feature_solution import predict
+     print("------------ Loading dataset------------ ")
+     params = hoho.get_params()
+     dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
+
+     print('------------ Now you can do your solution ---------------')
+     solution = []
+     # from concurrent.futures import ProcessPoolExecutor
+     # with ProcessPoolExecutor(max_workers=1) as pool:
+     #     results = []
+     #     for i, sample in enumerate(tqdm(dataset)):
+     #         results.append(pool.submit(predict, sample, visualize=False))
+
+     #     for i, result in enumerate(tqdm(results)):
+     #         key, pred_vertices, pred_edges = result.result()
+     #         solution.append({
+     #             '__key__': key,
+     #             'wf_vertices': pred_vertices.tolist(),
+     #             'wf_edges': pred_edges
+     #         })
+     ####### added for removing multiprocessing ########
+     for i, sample in enumerate(tqdm(dataset)):
+         key, pred_vertices, pred_edges = predict(sample, visualize=False)
+         solution.append({
+             '__key__': key,
+             'wf_vertices': pred_vertices.tolist(),
+             'wf_edges': pred_edges
+         })
+         ####### added for removing multiprocessing ########
+         if i % 100 == 0:
+             # incrementally save the results in case we run out of time
+             print(f"Processed {i} samples")
+             # save_submission(solution, Path(params['output_path']) / "submission.parquet")
+     print('------------ Saving results ---------------')
+     save_submission(solution, Path(params['output_path']) / "submission.parquet")
+     print("------------ Done ------------ ")