jacklangerman
commited on
Commit
•
3910cfe
1
Parent(s):
2ed8e4c
first commit
Browse files
hoho.py
DELETED
@@ -1,261 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
import shutil
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Dict
|
6 |
-
|
7 |
-
from PIL import ImageFile
|
8 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
9 |
-
|
10 |
-
LOCAL_DATADIR = None
|
11 |
-
|
12 |
-
def setup(local_dir='./data/usm-training-data/data'):
|
13 |
-
|
14 |
-
# If we are in the test environment, we need to link the data directory to the correct location
|
15 |
-
tmp_datadir = Path('/tmp/data/data')
|
16 |
-
local_test_datadir = Path('./data/usm-test-data-x/data')
|
17 |
-
local_val_datadir = Path(local_dir)
|
18 |
-
|
19 |
-
os.system('pwd')
|
20 |
-
os.system('ls -lahtr .')
|
21 |
-
|
22 |
-
if tmp_datadir.exists() and not local_test_datadir.exists():
|
23 |
-
global LOCAL_DATADIR
|
24 |
-
LOCAL_DATADIR = local_test_datadir
|
25 |
-
# shutil.move(datadir, './usm-test-data-x/data')
|
26 |
-
print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
|
27 |
-
LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
|
28 |
-
LOCAL_DATADIR.symlink_to(tmp_datadir)
|
29 |
-
else:
|
30 |
-
LOCAL_DATADIR = local_val_datadir
|
31 |
-
print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
|
32 |
-
|
33 |
-
# os.system("ls -lahtr")
|
34 |
-
|
35 |
-
assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
|
36 |
-
return LOCAL_DATADIR
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
import importlib
|
42 |
-
from pathlib import Path
|
43 |
-
import subprocess
|
44 |
-
|
45 |
-
def download_package(package_name, path_to_save='packages'):
|
46 |
-
"""
|
47 |
-
Downloads a package using pip and saves it to a specified directory.
|
48 |
-
|
49 |
-
Parameters:
|
50 |
-
package_name (str): The name of the package to download.
|
51 |
-
path_to_save (str): The path to the directory where the package will be saved.
|
52 |
-
"""
|
53 |
-
try:
|
54 |
-
# pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
|
55 |
-
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
|
56 |
-
"-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
|
57 |
-
"--platform", "manylinux1_x86_64", # Specify the platform
|
58 |
-
"--python-version", "38", # Specify the Python version
|
59 |
-
"--only-binary=:all:"]) # Download only binary packages
|
60 |
-
print(f'Package "{package_name}" downloaded successfully')
|
61 |
-
except subprocess.CalledProcessError as e:
|
62 |
-
print(f'Failed to downloaded package "{package_name}". Error: {e}')
|
63 |
-
|
64 |
-
|
65 |
-
def install_package_from_local_file(package_name, folder='packages'):
|
66 |
-
"""
|
67 |
-
Installs a package from a local .whl file or a directory containing .whl files using pip.
|
68 |
-
|
69 |
-
Parameters:
|
70 |
-
path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
|
71 |
-
"""
|
72 |
-
try:
|
73 |
-
pth = str(Path(folder) / package_name)
|
74 |
-
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
|
75 |
-
"--no-index", # Do not use package index
|
76 |
-
"--find-links", pth, # Look for packages in the specified directory or at the file
|
77 |
-
package_name]) # Specify the package to install
|
78 |
-
print(f"Package installed successfully from {pth}")
|
79 |
-
except subprocess.CalledProcessError as e:
|
80 |
-
print(f"Failed to install package from {pth}. Error: {e}")
|
81 |
-
|
82 |
-
|
83 |
-
def importt(module_name, as_name=None):
|
84 |
-
"""
|
85 |
-
Imports a module and returns it.
|
86 |
-
|
87 |
-
Parameters:
|
88 |
-
module_name (str): The name of the module to import.
|
89 |
-
as_name (str): The name to use for the imported module. If None, the original module name will be used.
|
90 |
-
|
91 |
-
Returns:
|
92 |
-
The imported module.
|
93 |
-
"""
|
94 |
-
for _ in range(2):
|
95 |
-
try:
|
96 |
-
if as_name is None:
|
97 |
-
print(f'imported {module_name}')
|
98 |
-
return importlib.import_module(module_name)
|
99 |
-
else:
|
100 |
-
print(f'imported {module_name} as {as_name}')
|
101 |
-
return importlib.import_module(module_name, as_name)
|
102 |
-
except ModuleNotFoundError as e:
|
103 |
-
install_package_from_local_file(module_name)
|
104 |
-
print(f"Failed to import module {module_name}. Error: {e}")
|
105 |
-
|
106 |
-
|
107 |
-
def prepare_submission():
|
108 |
-
# Download packages from requirements.txt
|
109 |
-
if Path('requirements.txt').exists():
|
110 |
-
print('downloading packages from requirements.txt')
|
111 |
-
Path('packages').mkdir(exist_ok=True)
|
112 |
-
with open('requirements.txt') as f:
|
113 |
-
packages = f.readlines()
|
114 |
-
for p in packages:
|
115 |
-
download_package(p.strip())
|
116 |
-
|
117 |
-
|
118 |
-
print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
|
119 |
-
|
120 |
-
|
121 |
-
def Rt_to_eye_target(im, K, R, t):
|
122 |
-
height = im.height
|
123 |
-
focal_length = K[0,0]
|
124 |
-
fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
|
125 |
-
|
126 |
-
x_axis, y_axis, z_axis = R
|
127 |
-
|
128 |
-
eye = -(R.T @ t).squeeze()
|
129 |
-
z_axis = z_axis.squeeze()
|
130 |
-
target = eye + z_axis
|
131 |
-
up = -y_axis
|
132 |
-
|
133 |
-
return eye, target, up, fov
|
134 |
-
|
135 |
-
|
136 |
-
########## general utilities ##########
|
137 |
-
import contextlib
|
138 |
-
import tempfile
|
139 |
-
from pathlib import Path
|
140 |
-
|
141 |
-
@contextlib.contextmanager
|
142 |
-
def working_directory(path):
|
143 |
-
"""Changes working directory and returns to previous on exit."""
|
144 |
-
prev_cwd = Path.cwd()
|
145 |
-
os.chdir(path)
|
146 |
-
try:
|
147 |
-
yield
|
148 |
-
finally:
|
149 |
-
os.chdir(prev_cwd)
|
150 |
-
|
151 |
-
@contextlib.contextmanager
|
152 |
-
def temp_working_directory():
|
153 |
-
with tempfile.TemporaryDirectory(dir='.') as D:
|
154 |
-
with working_directory(D):
|
155 |
-
yield
|
156 |
-
|
157 |
-
|
158 |
-
############# Dataset #############
|
159 |
-
def proc(row, split='train'):
|
160 |
-
# column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
|
161 |
-
# column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
|
162 |
-
# cols = column_names_train if split == 'train' else column_names_test
|
163 |
-
out = {}
|
164 |
-
for k, v in row.items():
|
165 |
-
colname = k.split('.')[0]
|
166 |
-
if colname in {'ade20k', 'depthcm', 'gestalt'}:
|
167 |
-
if colname in out:
|
168 |
-
out[colname].append(v)
|
169 |
-
else:
|
170 |
-
out[colname] = [v]
|
171 |
-
elif colname in {'wireframe', 'mesh'}:
|
172 |
-
# out.update({a: b.tolist() for a,b in v.items()})
|
173 |
-
out.update({a: b for a,b in v.items()})
|
174 |
-
elif colname in 'kr':
|
175 |
-
out[colname.upper()] = v
|
176 |
-
else:
|
177 |
-
out[colname] = v
|
178 |
-
|
179 |
-
return Sample(out)
|
180 |
-
|
181 |
-
|
182 |
-
class Sample(Dict):
|
183 |
-
def __repr__(self):
|
184 |
-
return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
def get_params():
|
189 |
-
exmaple_param_dict = {
|
190 |
-
"competition_id": "usm3d/S23DR",
|
191 |
-
"competition_type": "script",
|
192 |
-
"metric": "custom",
|
193 |
-
"token": "hf_**********************************",
|
194 |
-
"team_id": "local-test-team_id",
|
195 |
-
"submission_id": "local-test-submission_id",
|
196 |
-
"submission_id_col": "__key__",
|
197 |
-
"submission_cols": [
|
198 |
-
"__key__",
|
199 |
-
"wf_edges",
|
200 |
-
"wf_vertices",
|
201 |
-
"edge_semantics"
|
202 |
-
],
|
203 |
-
"submission_rows": 180,
|
204 |
-
"output_path": ".",
|
205 |
-
"submission_repo": "<THE HF MODEL ID of THIS REPO",
|
206 |
-
"time_limit": 7200,
|
207 |
-
"dataset": "usm3d/usm-test-data-x",
|
208 |
-
"submission_filenames": [
|
209 |
-
"submission.parquet"
|
210 |
-
]
|
211 |
-
}
|
212 |
-
|
213 |
-
param_path = Path('params.json')
|
214 |
-
|
215 |
-
if not param_path.exists():
|
216 |
-
print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
|
217 |
-
params = exmaple_param_dict
|
218 |
-
else:
|
219 |
-
print('found params.json (this means we are probably in the test env). Using params from file.')
|
220 |
-
with param_path.open() as f:
|
221 |
-
params = json.load(f)
|
222 |
-
print(params)
|
223 |
-
return params
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
import webdataset as wds
|
228 |
-
import numpy as np
|
229 |
-
|
230 |
-
def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
|
231 |
-
if LOCAL_DATADIR is None:
|
232 |
-
raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
|
233 |
-
|
234 |
-
local_dir = Path(LOCAL_DATADIR)
|
235 |
-
if split != 'all':
|
236 |
-
local_dir = local_dir / split
|
237 |
-
|
238 |
-
paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
|
239 |
-
|
240 |
-
dataset = wds.WebDataset(paths)
|
241 |
-
if decode is not None:
|
242 |
-
dataset = dataset.decode(decode)
|
243 |
-
else:
|
244 |
-
dataset = dataset.decode()
|
245 |
-
|
246 |
-
dataset = dataset.map(proc)
|
247 |
-
|
248 |
-
if dataset_type == 'webdataset':
|
249 |
-
return dataset
|
250 |
-
|
251 |
-
if dataset_type == 'hf':
|
252 |
-
import datasets
|
253 |
-
from datasets import Features, Value, Sequence, Image, Array2D
|
254 |
-
|
255 |
-
if split == 'train':
|
256 |
-
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
257 |
-
elif split == 'val':
|
258 |
-
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|