# Description: This file contains the handcrafted solution for the task of wireframe reconstruction
import io
from PIL import Image as PImage
import numpy as np
from collections import defaultdict
import cv2
from typing import Tuple, List
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN, OPTICS
from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping
DUMP_IMG = False
if DUMP_IMG:
from scipy.sparse import random
def empty_solution():
    '''Return a minimal valid solution: 2 vertices and 0 edges.'''
return np.zeros((2,3)), []
def convert_entry_to_human_readable(entry):
out = {}
already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
for k, v in entry.items():
if k in already_good:
out[k] = v
continue
if k == 'points3d':
out[k] = read_points3D_binary(fid=io.BytesIO(v))
if k == 'cameras':
out[k] = read_cameras_binary(fid=io.BytesIO(v))
if k == 'images':
out[k] = read_images_binary(fid=io.BytesIO(v))
if k in ['ade20k', 'gestalt']:
out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
if k == 'depthcm':
out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
return out
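# Usage sketch (illustrative only; assumes `entry` is one sample from the
# challenge webdataset with the binary fields decoded above):
'''
entry = next(iter(dataset))            # hypothetical dataset iterator
human = convert_entry_to_human_readable(entry)
K, R, t = human['K'][0], human['R'][0], human['t'][0]   # per-view camera
gest_img = human['gestalt'][0]         # PIL.Image, RGB
points3d = human['points3d']           # COLMAP Point3D dict
'''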
def get_uv_depth(vertices, depth):
'''Get the depth of the vertices from the depth image'''
uv = []
for v in vertices:
uv.append(v['xy'])
uv = np.array(uv)
uv_int = uv.astype(np.int32)
H, W = depth.shape[:2]
uv_int[:, 0] = np.clip( uv_int[:, 0], 0, W-1)
uv_int[:, 1] = np.clip( uv_int[:, 1], 0, H-1)
vertex_depth = depth[(uv_int[:, 1] , uv_int[:, 0])]
return uv, vertex_depth
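# Minimal self-contained check of the indexing convention used above:
# depth is indexed as depth[y, x], while vertices store xy = (x, y).
'''
import numpy as np
depth = np.arange(12, dtype=np.float32).reshape(3, 4)   # H=3, W=4
verts = [{'xy': np.array([2.7, 1.2])}]                  # x=2.7, y=1.2
uv, d = get_uv_depth(verts, depth)
assert d[0] == depth[1, 2]                              # row 1, col 2
'''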
def get_smooth_uv_depth(vertices, depth, gest_seg_np, sfm_depth_np, r=5):
'''Get the depth of the vertices from the depth image'''
uv = []
for v in vertices:
uv.append(v['xy'])
uv = np.array(uv)
uv_int = uv.astype(np.int32)
H, W = depth.shape[:2]
a = np.clip( uv_int[:, 0], 0, W-1)
b = np.clip( uv_int[:, 1], 0, H-1)
def get_local_depth(x,y, H, W, depth, r=r):
        '''Return the depth samples within radius r of (x, y), preferring SfM depth where available.'''
local_depths = []
for i in range(max(0, x - r), min(W, x + r)):
for j in range(max(0, y - r), min(H, y + r)):
if np.sqrt((i - x)**2 + (j - y)**2) <= r:
if sfm_depth_np is not None:
if sfm_depth_np[j, i] != 0:
local_depths.append(sfm_depth_np[j, i])
else:
local_depths.append(depth[j, i])
else:
local_depths.append(depth[j, i])
return local_depths
def get_local_min(x,y, H, W, depth, sfm_depth_np, r=r, PRINT=False):
        '''Return the minimum depth in a local window of half-size r, preferring SfM depth where available.'''
local_min = 9999999
i_range = range(max(0, x - r), min(W, x + r))
j_range = range(max(0, y - r), min(H, y + r))
for i in i_range:
for j in j_range:
if sfm_depth_np is not None:
if sfm_depth_np[j, i] != 0:
local_min = min(sfm_depth_np[j, i], local_min)
if PRINT: print(f'({j},{i})sfm:', sfm_depth_np[j, i])
else:
local_min = min(depth[j, i], local_min)
else:
local_min = min(depth[j, i], local_min)
return local_min
    def get_priority_local_min(x, y, H, W, depth, sfm_depth_np, r=r):
'''
Search on sfm depth first. Search on depthmap only if no sfm depth
exists at all in the local region.
'''
PRINT = False
r_choices = [5, 10, 20, 40, 75, 200]
for r in r_choices:
yslice = slice(max(0, y - r), min(H, y + r))
xslice = slice(max(0, x - r), min(W, x + r))
local_area = sfm_depth_np[yslice, xslice]
reduced_local_area = local_area[local_area!=0]
if reduced_local_area.size > 0:
break
if reduced_local_area.size > 0:
#print('use sfm')
if PRINT: print(reduced_local_area)
local_min = np.min(reduced_local_area)
return local_min
else:
#print('use both sfm and monocular')
return get_local_min(x,y, H, W, depth, sfm_depth_np, r, PRINT)
def get_local_min_progressive(x,y, H, W, depth, sfm_depth_np, r=r):
'''
If sfm is available in small local region, use it.
Otherwise, search in large region with combined depth
'''
small_r, large_r = 5, 75
PRINT= False
r = small_r
yslice = slice(max(0, y - r), min(H, y + r))
xslice = slice(max(0, x - r), min(W, x + r))
if np.any(sfm_depth_np[yslice, xslice] != 0):
return get_local_min(x,y, H, W, depth, sfm_depth_np, r)
else:
r = large_r
local_min = 9999999
i_range = range(max(0, x - r), min(W, x + r))
j_range = range(max(0, y - r), min(H, y + r))
for i in i_range:
for j in j_range:
if sfm_depth_np is not None:
if sfm_depth_np[j, i] != 0:
local_min = min(sfm_depth_np[j, i], local_min)
if PRINT: print(sfm_depth_np[j, i])
else:
local_min = min(depth[j, i], local_min)
if PRINT: print('dm:', depth[j, i])
else:
local_min = min(depth[j, i], local_min)
if PRINT: print('dm:', depth[j, i])
return local_min
vertex_depth = []
for x,y in zip(a,b):
        local_min = get_priority_local_min(x, y, H, W, depth, sfm_depth_np, r)
vertex_depth.append(local_min)
'''
local_depths = get_local_depth(x,y, H, W, depth, 5)
#local_mean = np.mean(local_depths)
local_mean = np.min(local_depths)
vertex_depth.append(local_mean)
'''
vertex_depth = np.array(vertex_depth)
return uv, vertex_depth
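# The priority lookup above widens the search window until it finds any
# non-zero SfM depth; only if none exists does it fall back to the
# monocular depth map. A standalone sketch of that policy (hypothetical
# arrays, same semantics):
'''
import numpy as np
sfm = np.zeros((100, 100)); sfm[60, 60] = 250.0   # one sparse SfM sample
mono = np.full((100, 100), 400.0)                 # dense monocular depth
for r in [5, 10, 20, 40, 75, 200]:
    patch = sfm[max(0, 50 - r):50 + r, max(0, 50 - r):50 + r]
    nz = patch[patch != 0]
    if nz.size:
        print('sfm depth at r =', r, '->', nz.min())  # hits at r=20
        break
else:
    print('fall back to monocular:', mono[50, 50])
'''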
''' Turn on this to speed up if you have numba
from numba import njit, prange
@njit(parallel=True)
def fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W):
for i in prange(max(0, u - dilate_r), min(W, u + dilate_r)):
for j in prange(max(0, v - dilate_r), min(H, v + dilate_r)) :
#checked+=1
existing_z = sfm_depth_np[j, i]
if z > 0:
if (existing_z!=0 and z < existing_z) or (existing_z==0):
sfm_depth_np[j, i] = z
if DUMP_IMG:
sfm_color_np[j, i] = c
return sfm_depth_np, sfm_color_np
'''
def get_SfM_depth(XYZ, rgb, depth_np, gest_seg_np, K, R, t, dilate_r = 5):
    '''Project the 3D SfM point cloud onto the image plane.'''
H, W = depth_np.shape[:2]
sfm_depth_np = np.zeros(depth_np.shape)
sfm_color_np = np.zeros(gest_seg_np.shape)
XYZ1 = np.concatenate((XYZ, np.ones((len(XYZ), 1))), axis=1)
Rt = np.concatenate( (R, t.reshape((3,1))), axis=1)
    proj = K @ Rt  # full world-to-image projection matrix
    xyz = proj @ XYZ1.transpose()
xyz = np.transpose(xyz)
valid_idx = ~np.isclose(xyz[:,2], 0, atol=1e-2) & ~np.isnan(xyz[:,0]) & ~np.isnan(xyz[:,1]) & ~np.isnan(xyz[:,2])
xyz = xyz[valid_idx, :]
us, vs, zs = xyz[:,0]/xyz[:,2], xyz[:,1]/xyz[:,2], xyz[:,2]
    # NaNs were already removed by valid_idx; filtering us/vs separately
    # here could desynchronize them from zs, so it is omitted.
us = us.astype(np.int32)
vs = vs.astype(np.int32)
for u,v,z,c in zip(us,vs,zs, rgb):
        ''' Use this instead if you have numba
sfm_depth_np, sfm_color_np = fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W)
'''
i_range = range(max(0, u - dilate_r), min(W, u + dilate_r))
j_range = range(max(0, v - dilate_r), min(H, v + dilate_r))
for i in i_range:
for j in j_range:
#checked+=1
existing_z = sfm_depth_np[j, i]
if z > 0:
if (existing_z!=0 and z < existing_z) or (existing_z==0):
sfm_depth_np[j, i] = z
if DUMP_IMG:
sfm_color_np[j, i] = c
if DUMP_IMG:
filename_sfm_depth = 'sfm_depth.png'
cv2.imwrite(filename_sfm_depth, sfm_depth_np/100)
filename_sfm_color = 'sfm_color.png'
cv2.imwrite(filename_sfm_color, sfm_color_np)
filename_ref_depth = 'ref_depth.png'
cv2.imwrite(filename_ref_depth, depth_np/100)
return sfm_depth_np
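# Sanity check of the projection used above, x = K [R|t] X, on a single
# synthetic point (identity pose; values are illustrative):
'''
import numpy as np
K = np.array([[1000., 0., 320.],
              [0., 1000., 240.],
              [0., 0., 1.]])
R, t = np.eye(3), np.zeros(3)
X = np.array([0.5, 0.2, 10.0, 1.0])                # homogeneous world point
xyz = K @ np.concatenate((R, t.reshape(3, 1)), axis=1) @ X
u, v, z = xyz[0] / xyz[2], xyz[1] / xyz[2], xyz[2]
assert np.isclose(u, 1000 * 0.5 / 10 + 320)        # u = fx*X/Z + cx = 370
assert np.isclose(v, 1000 * 0.2 / 10 + 240)        # v = fy*Y/Z + cy = 260
'''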
def get_vertices_and_edges_from_two_segmentations(ade_seg_np, gest_seg_np, edge_th = 50.0):
    '''Get the vertices and edges from the gestalt segmentation mask of the house.
    The ADE20K mask is only used for debug visualizations when DUMP_IMG is set.'''
vertices = []
connections = []
color_th = 10.0
#-------------------------
# combined map from ade
if DUMP_IMG:
ade_color0 = np.array([0,0,0])
ade_mask0 = cv2.inRange(ade_seg_np, ade_color0-0.5, ade_color0+0.5)
ade_color1 = np.array([120,120,120])
ade_mask1 = cv2.inRange(ade_seg_np, ade_color1-0.5, ade_color1+0.5)
ade_color2 = np.array([180,120,120])
ade_mask2 = cv2.inRange(ade_seg_np, ade_color2-0.5, ade_color2+0.5)
ade_color3 = np.array([255,9,224])
ade_mask3 = cv2.inRange(ade_seg_np, ade_color3-0.5, ade_color3+0.5)
ade_mask = cv2.bitwise_or(ade_mask3, ade_mask2)
ade_mask = cv2.bitwise_or(ade_mask1, ade_mask)
apex_map = np.zeros(ade_seg_np.shape)
apex_map_on_ade = ade_seg_np
apex_map_on_gest = gest_seg_np
# Apex
apex_color = np.array(gestalt_color_mapping['apex'])
apex_mask = cv2.inRange(gest_seg_np, apex_color-color_th, apex_color+color_th) # include more pts
#apex_mask = cv2.bitwise_and(apex_mask, ade_mask) # remove pts
if apex_mask.sum() > 0:
output = cv2.connectedComponentsWithStats(apex_mask, 8, cv2.CV_32S)
(numLabels, labels, stats, centroids) = output
stats, centroids = stats[1:], centroids[1:]
for i in range(numLabels-1):
vert = {"xy": centroids[i], "type": "apex"}
vertices.append(vert)
if DUMP_IMG:
uu = int(centroids[i][1])
vv = int(centroids[i][0])
# plot a cross
apex_map_on_ade[uu, vv] = (255,255,255)
shift=[(1,0),(-1,0),(0,1),(0,-1), (2,0),(-2,0),(0,2),(0,-2), (3,0),(-3,0),(0,3),(0,-3)]
h,w,_ = apex_map_on_ade.shape
for ss in shift:
if uu+ss[0] >= 0 and uu+ss[0] < h and vv+ss[1] >= 0 and vv+ss[1] < w:
apex_map[uu+ss[0], vv+ss[1]] = (255,255,255)
apex_map_on_ade[uu+ss[0], vv+ss[1]] = (255,255,255)
apex_map_on_gest[uu+ss[0], vv+ss[1]] = (255,255,255)
eave_end_color = np.array(gestalt_color_mapping['eave_end_point'])
eave_end_mask = cv2.inRange(gest_seg_np, eave_end_color-color_th, eave_end_color+color_th)
if eave_end_mask.sum() > 0:
output = cv2.connectedComponentsWithStats(eave_end_mask, 8, cv2.CV_32S)
(numLabels, labels, stats, centroids) = output
stats, centroids = stats[1:], centroids[1:]
for i in range(numLabels-1):
vert = {"xy": centroids[i], "type": "eave_end_point"}
vertices.append(vert)
if DUMP_IMG:
uu = int(centroids[i][1])
vv = int(centroids[i][0])
# plot a cross
apex_map_on_ade[uu, vv] = (255,0,0)
shift=[(1,0),(-1,0),(0,1),(0,-1), (2,0),(-2,0),(0,2),(0,-2), (3,0),(-3,0),(0,3),(0,-3)]
h,w,_ = apex_map_on_ade.shape
for ss in shift:
if uu+ss[0] >= 0 and uu+ss[0] < h and vv+ss[1] >= 0 and vv+ss[1] < w:
apex_map[uu+ss[0], vv+ss[1]] = (255,0,0)
apex_map_on_ade[uu+ss[0], vv+ss[1]] = (255,0,0)
apex_map_on_gest[uu+ss[0], vv+ss[1]] = (255,0,0)
flashing_end_color = np.array(gestalt_color_mapping['flashing_end_point'])
flashing_end_mask = cv2.inRange(gest_seg_np, flashing_end_color-color_th/2, flashing_end_color+color_th/2) # this color is sensitive
    if flashing_end_mask.sum() > 0:
output = cv2.connectedComponentsWithStats(flashing_end_mask, 8, cv2.CV_32S)
(numLabels, labels, stats, centroids) = output
stats, centroids = stats[1:], centroids[1:]
for i in range(numLabels-1):
vert = {"xy": centroids[i], "type": "flashing_end_point"}
vertices.append(vert)
if DUMP_IMG:
uu = int(centroids[i][1])
vv = int(centroids[i][0])
# plot a cross
apex_map_on_ade[uu, vv] = (255,0,0)
shift=[(1,0),(-1,0),(0,1),(0,-1), (2,0),(-2,0),(0,2),(0,-2), (3,0),(-3,0),(0,3),(0,-3)]
h,w,_ = apex_map_on_ade.shape
for ss in shift:
if uu+ss[0] >= 0 and uu+ss[0] < h and vv+ss[1] >= 0 and vv+ss[1] < w:
apex_map[uu+ss[0], vv+ss[1]] = (255,0,0)
apex_map_on_ade[uu+ss[0], vv+ss[1]] = (255,0,0)
apex_map_on_gest[uu+ss[0], vv+ss[1]] = (255,0,0)
# imsave apex and eave_end
if DUMP_IMG:
import random
rid = random.random()
filename_apex_ade = f'apex_map_on_ade_{rid}.jpg'
cv2.imwrite(filename_apex_ade, apex_map_on_ade)
filename_apex_gest = f'apex_map_on_gest_{rid}.jpg'
cv2.imwrite(filename_apex_gest, apex_map_on_gest)
filename_apex_map = f'apex_map_{rid}.jpg'
cv2.imwrite(filename_apex_map, apex_map)
# Connectivity
apex_pts = []
apex_pts_idxs = []
for j, v in enumerate(vertices):
apex_pts.append(v['xy'])
apex_pts_idxs.append(j)
apex_pts = np.array(apex_pts)
# Turns out connection is not a priority
'''
# Ridge connects two apex points
def Ridge_connects_two_apex_points(gest_seg_np, color_th, apex_pts, edge_th):
conn = []
line_img = np.copy(gest_seg_np) * 0
for edge_class in ['eave', 'ridge', 'rake', 'valley']:
edge_color = np.array(gestalt_color_mapping[edge_class])
mask = cv2.morphologyEx(cv2.inRange(gest_seg_np,
edge_color-color_th,
edge_color+color_th),
cv2.MORPH_DILATE, np.ones((11, 11)))
#line_img = np.copy(gest_seg_np) * 0
if mask.sum() > 0:
output = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
(numLabels, labels, stats, centroids) = output
stats, centroids = stats[1:], centroids[1:]
edges = []
for i in range(1, numLabels):
y,x = np.where(labels == i)
xleft_idx = np.argmin(x)
x_left = x[xleft_idx]
y_left = y[xleft_idx]
xright_idx = np.argmax(x)
x_right = x[xright_idx]
y_right = y[xright_idx]
edges.append((x_left, y_left, x_right, y_right))
cv2.line(line_img, (x_left, y_left), (x_right, y_right), (255, 255, 255), 2)
edges = np.array(edges)
if (len(apex_pts) < 2) or len(edges) <1:
continue
pts_to_edges_dist = np.minimum(cdist(apex_pts, edges[:,:2]), cdist(apex_pts, edges[:,2:]))
connectivity_mask = pts_to_edges_dist <= edge_th
edge_connects = connectivity_mask.sum(axis=0)
for edge_idx, edgesum in enumerate(edge_connects):
if edgesum>=2:
connected_verts = np.where(connectivity_mask[:,edge_idx])[0]
for a_i, a in enumerate(connected_verts):
for b in connected_verts[a_i+1:]:
conn.append((a, b))
return conn, line_img
connections, line_img = Ridge_connects_two_apex_points(gest_seg_np, color_th, apex_pts, edge_th)
'''
'''
def classifyPairs(apex_pts, apex_pts_idxs, gest_seg_np, apex_mask, eave_end_mask):
conn = []
# Plot all possible connection pixels in one mask
mask = cv2.bitwise_or(apex_mask, eave_end_mask)
#for edge_class in ['eave', 'ridge', 'rake', 'valley', 'step_flashing' ]:#, 'flashing']:
for edge_class in ['eave', 'ridge', 'rake', 'valley', 'step_flashing' , 'flashing']:
edge_color = np.array(gestalt_color_mapping[edge_class])
mask_e = cv2.morphologyEx(cv2.inRange(gest_seg_np,
edge_color-color_th,
edge_color+color_th),
cv2.MORPH_DILATE, np.ones((11, 11)))
mask = cv2.bitwise_or(mask, mask_e)
    # try connecting each pair and see if the cost on the mask is too high
def count_on_line_segment(mask, x1, y1, x2, y2, num_points=100):
#points = []
score = 0
#score_vertex = 0
diffx = x2 - x1
diffy = y2 - y1
for t in range(num_points + 1):
t /= num_points
x = x1 + t * diffx
y = y1 + t * diffy
x, y = x.astype(np.int32), y.astype(np.int32)
if mask[y,x] > 0:
score += 1
#if apex_mask[y,x] > 0:
# score_vertex += 1
return score/num_points #, score_vertex/num_points
#points.append((x, y))
#return points
conn_thr = 0.8 # 80% of pixels are connectivity pixels
for p1i in apex_pts_idxs:
for p2i in apex_pts_idxs:
if p1i == p2i:
continue
score = count_on_line_segment(mask, apex_pts[p1i][0], apex_pts[p1i][1], apex_pts[p2i][0], apex_pts[p2i][1], num_points=100)
#print(f'{p1i}, {p2i}, score = {score}')
if score>conn_thr and ((p2i,p1i) not in conn):
conn.append((p1i, p2i))
return conn, mask
connections, line_img = classifyPairs(apex_pts, apex_pts_idxs, gest_seg_np, apex_mask, eave_end_mask)
#print(f'{len(vertices)} vertices: {vertices}')
#print(len(connections), ' connections: ', connections)
if DUMP_IMG:
filename_edges_map = f'edges_map_{rid}.jpg'
if 'line_img' in locals():
cv2.imwrite(filename_edges_map, line_img)
'''
    # Edge recovery was disabled (see the blocks above); the baseline returns vertices only.
    connections = []
return vertices, connections
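# The vertex detector above boils down to: threshold one gestalt color,
# then take connected-component centroids. A compact, runnable sketch of
# that step on a synthetic mask (the color value is made up):
'''
import cv2
import numpy as np
seg = np.zeros((64, 64, 3), np.uint8)
seg[10:14, 10:14] = (235, 114, 186)                       # fake 'apex' blob
color = np.array([235, 114, 186])
mask = cv2.inRange(seg, color - 10.0, color + 10.0)
n, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
print(centroids[1:])   # row 0 is background; prints [[11.5, 11.5]]
'''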
def merge_vertices_3d(vert_edge_per_image, th=0.1):
    '''Merge vertices that are close to each other in 3D space and are of the same type'''
all_3d_vertices = []
connections_3d = []
all_indexes = []
cur_start = 0
types = []
for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
        types += [int(v['type']=='apex') for v in vertices]  # 1 = apex, 0 = eave/flashing end point
all_3d_vertices.append(vertices_3d)
connections_3d+=[(x+cur_start,y+cur_start) for (x,y) in connections]
cur_start+=len(vertices_3d)
all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
distmat = cdist(all_3d_vertices, all_3d_vertices)
types = np.array(types).reshape(-1,1)
same_types = cdist(types, types)
mask_to_merge = (distmat <= th) & (same_types==0)
new_vertices = []
new_connections = []
to_merge = sorted(list(set([tuple(a.nonzero()[0].tolist()) for a in mask_to_merge])))
to_merge_final = defaultdict(list)
for i in range(len(all_3d_vertices)):
for j in to_merge:
if i in j:
to_merge_final[i]+=j
for k, v in to_merge_final.items():
to_merge_final[k] = list(set(v))
already_there = set()
merged = []
for k, v in to_merge_final.items():
if k in already_there:
continue
merged.append(v)
for vv in v:
already_there.add(vv)
old_idx_to_new = {}
count=0
for idxs in merged:
new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
for idx in idxs:
old_idx_to_new[idx] = count
count +=1
new_vertices=np.array(new_vertices)
for conn in connections_3d:
new_con = sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]]))
if new_con[0] == new_con[1]:
continue
if new_con not in new_connections:
new_connections.append(new_con)
return new_vertices, new_connections
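# Toy illustration of the merge: two near-duplicate apex detections from
# different views collapse into one averaged vertex. Note the function only
# reads 'type' from the 2D vertices (numbers are illustrative):
'''
import numpy as np
per_image = {
    0: ([{'type': 'apex'}], [], np.array([[0., 0., 100.]])),
    1: ([{'type': 'apex'}], [], np.array([[0., 0., 100.05]])),
}
verts, conns = merge_vertices_3d(per_image, th=0.1)
print(verts)   # one vertex at z = 100.025
'''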
def prune_not_connected(all_3d_vertices, connections_3d):
'''Prune vertices that are not connected to any other vertex'''
connected = defaultdict(list)
for c in connections_3d:
connected[c[0]].append(c)
connected[c[1]].append(c)
new_indexes = {}
new_verts = []
connected_out = []
for k,v in connected.items():
vert = all_3d_vertices[k]
if tuple(vert) not in new_verts:
new_verts.append(tuple(vert))
new_indexes[k]=len(new_verts) -1
for k,v in connected.items():
for vv in v:
connected_out.append((new_indexes[vv[0]],new_indexes[vv[1]]))
connected_out=list(set(connected_out))
return np.array(new_verts), connected_out
def uv_to_v3d(uv, depth_vert, K, R, t):
    # Convert pixel coordinates to normalized camera coordinates using the intrinsics
xy_local = np.ones((len(uv), 3))
xy_local[:, 0] = (uv[:, 0] - K[0,2]) / K[0,0]
xy_local[:, 1] = (uv[:, 1] - K[1,2]) / K[1,1]
# Get the 3D vertices
vertices_3d_local = depth_vert[...,None] * (xy_local/np.linalg.norm(xy_local, axis=1)[...,None])
world_to_cam = np.eye(4)
world_to_cam[:3, :3] = R
world_to_cam[:3, 3] = t.reshape(-1)
cam_to_world = np.linalg.inv(world_to_cam)
vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
return vertices_3d
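# Round-trip sketch for the back-projection above. Note the ray is
# normalized, so depth_vert is treated as distance along the viewing ray,
# not as z-depth (identity pose; illustrative values):
'''
import numpy as np
K = np.array([[1000., 0., 320.],
              [0., 1000., 240.],
              [0., 0., 1.]])
R, t = np.eye(3), np.zeros(3)
uv = np.array([[320., 240.]])              # principal point -> ray (0, 0, 1)
v3d = uv_to_v3d(uv, np.array([10.0]), K, R, t)
print(v3d)                                 # ~[[0., 0., 10.]]
'''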
def delete_one_vert(vertices, vertices_3d, connections, vert_to_del):
    i = np.where(np.all(abs(vertices_3d - vert_to_del) < 0.01, axis=1))
    if len(i[0]) == 0:
        # Vertex not found: return unchanged, with the same arity as the
        # corresponding success path below.
        if vertices:
            return vertices, vertices_3d, connections
        return vertices_3d, connections
    idx = i[0]  # index array; in practice a single match
if vertices:
vertices = np.delete(vertices, idx)
vertices_3d = np.delete(vertices_3d, idx, axis=0)
conn_to_del = []
for ic, c in enumerate(connections):
if c[0] == idx or c[1] == idx:
conn_to_del.append(ic)
connections = np.delete(connections, (conn_to_del), axis=0)
for ic, c in enumerate(connections):
if c[0] >= idx:
connections[ic] = (connections[ic][0]-1, connections[ic][1])
if c[1] >= idx:
connections[ic] = (connections[ic][0], connections[ic][1]-1)
connections = connections.tolist()
if vertices:
return vertices, vertices_3d, connections
else:
return vertices_3d, connections
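# Behavior sketch: deleting vertex 1 drops its edges and shifts higher
# indices down by one (illustrative arrays):
'''
import numpy as np
v3d = np.array([[0., 0., 0.], [5., 0., 0.], [9., 0., 0.]])
conns = [(0, 1), (1, 2), (0, 2)]
v3d2, conns2 = delete_one_vert([], v3d, conns, np.array([5., 0., 0.]))
print(v3d2)    # vertices 0 and 2 remain
print(conns2)  # [[0, 1]] -- the old (0, 2) edge, reindexed
'''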
def prune_far(all_3d_vertices, connections_3d, prune_dist_thr=3000):
    '''Prune vertices whose nearest neighbor is farther than prune_dist_thr'''
if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
return all_3d_vertices, connections_3d
isolated = []
distmat = cdist(all_3d_vertices, all_3d_vertices)
for i, v in enumerate(distmat):
exclude_self = np.array([x for idx,x in enumerate(v) if idx!=i])
exclude_self = abs(exclude_self)
if min(exclude_self) > prune_dist_thr:
isolated.append(i)
break
while isolated:
isolated_pt = isolated.pop()
#print('isolated:', isolated_pt)
pt_to_del = all_3d_vertices[isolated_pt]
all_3d_vertices, connections_3d = delete_one_vert([], all_3d_vertices, connections_3d, pt_to_del)
if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
return all_3d_vertices, connections_3d
distmat = cdist(all_3d_vertices, all_3d_vertices)
for i, v in enumerate(distmat):
exclude_self = np.array([x for idx,x in enumerate(v) if idx!=i])
#if np.any(exclude_self > prune_dist_thr):
exclude_self = abs(exclude_self)
if min(exclude_self) > prune_dist_thr:
#print('del a pt w/ dist = ', min(exclude_self))
isolated.append(i)
break
return all_3d_vertices, connections_3d
def prune_tall_short(all_3d_vertices, connections_3d, lowest_z, prune_tall_thr=1000, prune_short_thr=100):
    '''Prune vertices whose height above the lowest point is implausibly tall or short'''
if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
return all_3d_vertices, connections_3d
isolated = []
for i,v in enumerate(all_3d_vertices):
if v[2]-lowest_z > prune_tall_thr or v[2]-lowest_z < prune_short_thr:
isolated.append(i)
break
while isolated:
isolated_pt = isolated.pop()
#print('isolated:', isolated_pt)
pt_to_del = all_3d_vertices[isolated_pt]
all_3d_vertices, connections_3d = delete_one_vert([], all_3d_vertices, connections_3d, pt_to_del)
if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
return all_3d_vertices, connections_3d
for i,v in enumerate(all_3d_vertices):
if v[2]-lowest_z > prune_tall_thr or v[2]-lowest_z < prune_short_thr:
isolated.append(i)
break
return all_3d_vertices, connections_3d
def clean_gest(gest_seg_np):
'''
    Remove all blobs that are not connected to the largest blob
'''
bg_color = np.array(gestalt_color_mapping['unclassified'])
bg_mask = cv2.inRange(gest_seg_np, bg_color-10, bg_color+10)
    if bg_mask.sum() == 0 or np.count_nonzero(bg_mask) == bg_mask.size:  # no background, or all background (inRange yields 0/255)
return gest_seg_np
fg_mask = cv2.bitwise_not(bg_mask)
if fg_mask.sum() > 0:
output = cv2.connectedComponentsWithStats(fg_mask, 8, cv2.CV_32S)
(numLabels, labels, stats, centroids) = output
sizes = stats[1:, -1] # Get the areas (skip the first entry which is the background)
    max_label = np.argmax(sizes) + 1 # Add 1 to get the actual label (0 is the background)
# mask out anything that doesn't belong to the largest component
gest_seg_np[labels != max_label] = bg_color
return gest_seg_np
def clean_PCD(XYZ, rgb):
'''
Remove all points that do not belong to the largest cluster
'''
lowest_z = 0
center_thr = 500
largest_blob_size = 0
largest_blob = 0
# avoid memory issue
if len(XYZ) > 130000 or len(XYZ) < 20:
return XYZ, rgb, lowest_z
# clustering
clust = OPTICS(min_samples=20, max_eps=150, metric='euclidean', cluster_method='dbscan', algorithm='kd_tree').fit(XYZ)
labels = clust.labels_
unique_labels = set(labels)
    retain_class_mask = labels == -2  # labels are >= -1, so this is an all-False initializer
if len(unique_labels) > 40 or len(unique_labels) == 1:
return XYZ, rgb, lowest_z
for k in unique_labels:
class_member_mask = labels == k
blob_size = np.count_nonzero(class_member_mask)
if blob_size>largest_blob_size:
largest_blob_size = blob_size
largest_blob = k
for k in unique_labels:
'''
# -1 is the noise cluster
if k == -1:
retain_class_mask = retain_class_mask | class_member_mask
continue
'''
''' center prior is not valid
pt_k = XYZ[class_member_mask]
Xmean = np.mean(pt_k[:,0])
Ymean = np.mean(pt_k[:,1])
if abs(Xmean) < center_thr and abs(Ymean) < center_thr:
retain_class_mask = retain_class_mask | class_member_mask
'''
if k == largest_blob:
class_member_mask = labels == k
retain_class_mask = retain_class_mask | class_member_mask
#pt_k = XYZ[class_member_mask]
#lowest_z = min(pt_k[:,2])
break
XYZ = XYZ[retain_class_mask]
rgb = rgb[retain_class_mask]
return XYZ, rgb, lowest_z
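# What the clustering step does, in isolation: label points, keep the
# largest cluster. A small sketch with DBSCAN (imported above) rather than
# OPTICS, purely for illustration:
'''
import numpy as np
from sklearn.cluster import DBSCAN
rng = np.random.default_rng(0)
house = rng.normal(0, 50, (500, 3))         # dense main structure
outliers = rng.normal(2000, 10, (20, 3))    # far-away noise blob
pts = np.vstack([house, outliers])
labels = DBSCAN(eps=150, min_samples=20).fit(pts).labels_
largest = max(set(labels) - {-1}, key=lambda k: (labels == k).sum())
print((labels == largest).sum())            # ~500: the main structure
'''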
def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5) -> Tuple[str, np.ndarray, List[Tuple[int, int]]]:
good_entry = convert_entry_to_human_readable(entry)
points3D = good_entry['points3d']
XYZ = np.stack([p.xyz for p in points3D.values()])
rgb = np.stack([p.rgb for p in points3D.values()])
lowest_z = min(XYZ[:,2])
XYZ, rgb, lowest_z = clean_PCD(XYZ, rgb)
del points3D
vert_edge_per_image = {}
for i, (ade, gest, depth, K, R, t) in enumerate(zip(
good_entry['ade20k'],
good_entry['gestalt'],
good_entry['depthcm'],
good_entry['K'],
good_entry['R'],
good_entry['t']
)):
'''
debug per view
if i!=3:
continue
'''
# (1) 2D processing
ade_seg = ade.resize(depth.size)
ade_seg_np = np.array(ade_seg).astype(np.uint8)
gest_seg = gest.resize(depth.size)
gest_seg_np = np.array(gest_seg).astype(np.uint8)
gest_seg_np = clean_gest(gest_seg_np)
# Metric3D
depth_np = np.array(depth) / depth_scale
vertices, connections = get_vertices_and_edges_from_two_segmentations(ade_seg_np, gest_seg_np, edge_th = 50.)
if (len(vertices) < 1):
vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
continue
# (2) Use depth
sfm_depth_np = get_SfM_depth(XYZ, rgb, depth_np, gest_seg_np, K, R, t, 5)
uv, depth_vert = get_smooth_uv_depth(vertices, depth_np, gest_seg_np, sfm_depth_np, 75)
vertices_3d = uv_to_v3d(uv, depth_vert, K, R, t)
vert_edge_per_image[i] = vertices, connections, vertices_3d
# (3) aggregate info collected from all views:
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 150)
#all_3d_vertices, connections_3d = prune_tall_short(all_3d_vertices, connections_3d, lowest_z, 1000, 0)
''' This didn't help the final solution
if len(all_3d_vertices)>35:
all_3d_vertices, connections_3d = prune_not_connected(all_3d_vertices, connections_3d)
'''
if len(all_3d_vertices)>10:
all_3d_vertices_clean, connections_3d_clean = prune_far(all_3d_vertices, connections_3d, prune_dist_thr=prune_dist_thr)
else:
all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
    # Submit vertices only; predicted edges did not help the final score.
    connections_3d_clean = []
if (len(all_3d_vertices_clean) < 2):
        print('Not enough vertices in the 3D solution, returning the empty solution')
return (good_entry['__key__'], *empty_solution())
if visualize:
print(f"num of est: {len(all_3d_vertices_clean)}, num of gt:{len(good_entry['wf_vertices'])}")
from hoho.viz3d import plot_estimate_and_gt
plot_estimate_and_gt( all_3d_vertices_clean,
connections_3d_clean,
good_entry['wf_vertices'],
good_entry['wf_edges'])
return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean
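# End-to-end usage sketch (hypothetical dataset handle; `visualize=True`
# additionally plots the estimate against the ground-truth wireframe):
'''
for entry in dataset:                       # hypothetical iterable of samples
    key, vertices_3d, edges = predict(entry, visualize=False)
    print(key, vertices_3d.shape, len(edges))
'''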