# Description: This file contains the handcrafted solution for the task of wireframe reconstruction
import io
from collections import defaultdict
from typing import Any, List, Tuple

import cv2
import numpy as np
from PIL import Image as PImage
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN, OPTICS

from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping

# Set to True to write intermediate debug images (SfM depth, detected vertices) to the CWD.
DUMP_IMG = False
if DUMP_IMG:
    from scipy.sparse import random

# Window radii (pixels) tried in increasing order when looking for SfM depth around a vertex.
_SFM_SEARCH_RADII = (5, 10, 20, 40, 75, 200)

# Upper bound / fallback returned by the local-min depth search.
_NO_DEPTH = 9999999

# (row, col) offsets of the cross marker drawn around vertices in the debug images.
_CROSS_SHIFTS = ((1, 0), (-1, 0), (0, 1), (0, -1),
                 (2, 0), (-2, 0), (0, 2), (0, -2),
                 (3, 0), (-3, 0), (0, 3), (0, -3))


def empty_solution():
    '''Return a minimal valid solution, i.e. 2 vertices and 0 edge.'''
    return np.zeros((2, 3)), []


def convert_entry_to_human_readable(entry):
    '''Decode the raw fields of a dataset entry into Python objects.

    Keys in ``already_good`` are copied unchanged, COLMAP binaries
    ('points3d', 'cameras', 'images') are parsed with the hoho readers, and
    image lists ('ade20k', 'gestalt', 'depthcm') are decoded with PIL.
    Any other key is dropped.
    '''
    out = {}
    already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices',
                    'mesh_faces', 'face_semantics', 'K', 'R', 't']
    for k, v in entry.items():
        if k in already_good:
            out[k] = v
        elif k == 'points3d':
            out[k] = read_points3D_binary(fid=io.BytesIO(v))
        elif k == 'cameras':
            out[k] = read_cameras_binary(fid=io.BytesIO(v))
        elif k == 'images':
            out[k] = read_images_binary(fid=io.BytesIO(v))
        elif k in ('ade20k', 'gestalt'):
            out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
        elif k == 'depthcm':
            out[k] = [PImage.open(io.BytesIO(x)) for x in v]
    return out


def _vertices_to_uv(vertices):
    '''Stack the 'xy' pixel coordinates of 2D vertices into an (N, 2) float array.'''
    return np.array([v['xy'] for v in vertices], dtype=np.float64).reshape(-1, 2)


def get_uv_depth(vertices, depth):
    '''Get the depth of the vertices from the depth image.

    Each vertex is sampled at its integer-truncated pixel, clipped to the image.

    Returns:
        (uv, vertex_depth): the (N, 2) pixel coordinates and the sampled depths.
    '''
    uv = _vertices_to_uv(vertices)
    H, W = depth.shape[:2]
    u_int = np.clip(uv[:, 0].astype(np.int32), 0, W - 1)
    v_int = np.clip(uv[:, 1].astype(np.int32), 0, H - 1)
    vertex_depth = depth[v_int, u_int]
    return uv, vertex_depth


def _window(cx, cy, r, H, W):
    '''Return (row_slice, col_slice) of the half-open square [cy-r, cy+r) x [cx-r, cx+r).

    The square is clipped to an H x W image. The stop index is never smaller
    than the start, so a square lying fully outside the image yields empty
    slices instead of wrapping around through negative indices.
    '''
    cx, cy = int(cx), int(cy)
    x0 = max(0, cx - r)
    y0 = max(0, cy - r)
    x1 = max(x0, min(W, cx + r))
    y1 = max(y0, min(H, cy + r))
    return slice(y0, y1), slice(x0, x1)


def _local_min_depth(x, y, r, depth, sfm_depth_np):
    '''Minimum depth in the radius-``r`` window around pixel (x, y).

    A pixel contributes its ``sfm_depth_np`` value when that is non-zero, and its
    ``depth`` value otherwise (or always, when ``sfm_depth_np`` is None).
    The result never exceeds ``_NO_DEPTH``.
    '''
    H, W = depth.shape[:2]
    rows, cols = _window(x, y, r, H, W)
    local = depth[rows, cols]
    if sfm_depth_np is not None:
        sfm_local = sfm_depth_np[rows, cols]
        local = np.where(sfm_local != 0, sfm_local, local)
    if local.size == 0:
        return _NO_DEPTH
    return min(local.min(), _NO_DEPTH)


def _priority_local_min(x, y, depth, sfm_depth_np, radii=_SFM_SEARCH_RADII):
    '''Depth for pixel (x, y), preferring SfM depth over the monocular depth map.

    The windows in ``radii`` are tried from small to large. The first one that
    contains any non-zero SfM depth returns the minimum of those SfM values.
    If none does (or no SfM map is given), the minimum over the largest window
    is returned, mixing in the monocular depth where SfM depth is missing.
    '''
    if sfm_depth_np is None:
        return _local_min_depth(x, y, radii[-1], depth, None)
    H, W = depth.shape[:2]
    for r in radii:
        rows, cols = _window(x, y, r, H, W)
        local = sfm_depth_np[rows, cols]
        valid = local[local != 0]
        if valid.size > 0:
            return valid.min()
    return _local_min_depth(x, y, radii[-1], depth, sfm_depth_np)


def get_smooth_uv_depth(vertices, depth, gest_seg_np, sfm_depth_np, r=5):
    '''Get a robust depth for each vertex from the SfM and monocular depth maps.

    Vertex pixels are truncated to int and clipped to the image. Each depth comes
    from ``_priority_local_min``, which searches its own fixed radius schedule.
    ``gest_seg_np`` and ``r`` are accepted for interface compatibility and are
    not used.

    Returns:
        (uv, vertex_depth): the (N, 2) pixel coordinates and an (N,) depth array.
    '''
    uv = _vertices_to_uv(vertices)
    H, W = depth.shape[:2]
    us = np.clip(uv[:, 0].astype(np.int32), 0, W - 1)
    vs = np.clip(uv[:, 1].astype(np.int32), 0, H - 1)
    vertex_depth = np.array([_priority_local_min(x, y, depth, sfm_depth_np) for x, y in zip(us, vs)])
    return uv, vertex_depth


def get_SfM_depth(XYZ, rgb, depth_np, gest_seg_np, K, R, t, dilate_r=5):
    '''Project the 3D SfM point cloud to the image plane as a sparse depth map.

    Points are projected with ``K [R | t]``. Points with a non-finite projection,
    a camera-frame z within 1e-2 of zero, or a negative z are discarded. Each
    remaining point splats its z into the (2*dilate_r)^2 window around its
    pixel. The smallest positive z wins per pixel (z-buffer); 0 means "no SfM
    depth". ``depth_np`` is only used for the output shape (and for the debug
    dump), and ``gest_seg_np`` only for the shape of the debug colour image.

    Returns:
        An array shaped like ``depth_np`` holding SfM depths (0 where missing).
    '''
    H, W = depth_np.shape[:2]
    sfm_depth_np = np.zeros(depth_np.shape)
    sfm_color_np = np.zeros(gest_seg_np.shape)

    XYZ = np.asarray(XYZ).reshape(-1, 3)
    rgb = np.asarray(rgb)
    XYZ1 = np.concatenate((XYZ, np.ones((len(XYZ), 1))), axis=1)
    Rt = np.concatenate((R, t.reshape((3, 1))), axis=1)
    xyz = (K @ Rt @ XYZ1.T).T

    zs = xyz[:, 2]
    valid = np.all(np.isfinite(xyz), axis=1) & ~np.isclose(zs, 0, atol=1e-2) & (zs > 0)
    # Filter colours with the same mask so they stay aligned with their points.
    xyz, rgb = xyz[valid], rgb[valid]
    us, vs, zs = xyz[:, 0] / xyz[:, 2], xyz[:, 1] / xyz[:, 2], xyz[:, 2]

    # Points whose splat window cannot touch the image are dropped before the int cast,
    # so far-away projections cannot overflow int32.
    in_reach = ((us > -dilate_r - 1) & (us < W + dilate_r + 1)
                & (vs > -dilate_r - 1) & (vs < H + dilate_r + 1))
    us = us[in_reach].astype(np.int32)
    vs = vs[in_reach].astype(np.int32)
    zs = zs[in_reach]
    rgb = rgb[in_reach]

    for u, v, z, c in zip(us, vs, zs, rgb):
        rows, cols = _window(u, v, dilate_r, H, W)
        patch = sfm_depth_np[rows, cols]  # basic slicing -> view, writes go to sfm_depth_np
        closer = (patch == 0) | (z < patch)
        patch[closer] = z
        if DUMP_IMG:
            sfm_color_np[rows, cols][closer] = c

    if DUMP_IMG:
        filename_sfm_depth = 'sfm_depth.png'
        cv2.imwrite(filename_sfm_depth, sfm_depth_np / 100)
        filename_sfm_color = 'sfm_color.png'
        cv2.imwrite(filename_sfm_color, sfm_color_np)
        filename_ref_depth = 'ref_depth.png'
        cv2.imwrite(filename_ref_depth, depth_np / 100)
    return sfm_depth_np


def _mask_centroids(mask):
    '''Return the (x, y) centroids of the 8-connected components of a binary mask.'''
    if np.count_nonzero(mask) == 0:
        return np.empty((0, 2))
    num_labels, _, _, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
    return centroids[1:num_labels]  # label 0 is the background


def _draw_cross(centroid, color, apex_map, apex_map_on_ade, apex_map_on_gest):
    '''Debug helper: mark a vertex centroid with a small cross on the three debug images.'''
    uu = int(centroid[1])
    vv = int(centroid[0])
    apex_map_on_ade[uu, vv] = color
    h, w, _ = apex_map_on_ade.shape
    for du, dv in _CROSS_SHIFTS:
        row, col = uu + du, vv + dv
        if 0 <= row < h and 0 <= col < w:
            apex_map[row, col] = color
            apex_map_on_ade[row, col] = color
            apex_map_on_gest[row, col] = color


def get_vertices_and_edges_from_two_segmentations(ade_seg_np, gest_seg_np, edge_th=50.0):
    '''Get the 2D vertices (and edges) from the gestalt segmentation mask of the house.

    Vertices are the centroids of the connected components of the 'apex',
    'eave_end_point' and 'flashing_end_point' gestalt colours. Each vertex is a
    dict ``{"xy": centroid, "type": class_name}``.

    Edge detection is disabled in this solution: the returned connection list is
    always empty, and ``edge_th`` is unused. ``ade_seg_np`` is only used for the
    debug images when DUMP_IMG is set.

    Returns:
        (vertices, connections)
    '''
    color_th = 10.0
    # (gestalt class, colour tolerance, debug marker colour)
    vertex_classes = (
        ('apex', color_th, (255, 255, 255)),
        ('eave_end_point', color_th, (255, 0, 0)),
        ('flashing_end_point', color_th / 2, (255, 0, 0)),  # this color is sensitive
    )

    if DUMP_IMG:
        # Copies, so the debug drawing does not alter the caller's segmentation arrays.
        apex_map = np.zeros(ade_seg_np.shape)
        apex_map_on_ade = ade_seg_np.copy()
        apex_map_on_gest = gest_seg_np.copy()

    vertices = []
    for vtype, tol, marker_color in vertex_classes:
        color = np.array(gestalt_color_mapping[vtype])
        mask = cv2.inRange(gest_seg_np, color - tol, color + tol)
        for centroid in _mask_centroids(mask):
            vertices.append({"xy": centroid, "type": vtype})
            if DUMP_IMG:
                _draw_cross(centroid, marker_color, apex_map, apex_map_on_ade, apex_map_on_gest)

    if DUMP_IMG:
        import random
        rid = random.random()
        cv2.imwrite(f'apex_map_on_ade_{rid}.jpg', apex_map_on_ade)
        cv2.imwrite(f'apex_map_on_gest_{rid}.jpg', apex_map_on_gest)
        cv2.imwrite(f'apex_map_{rid}.jpg', apex_map)

    # Connectivity between vertices turned out not to be a priority; no edges are predicted.
    connections = []
    return vertices, connections


def merge_vertices_3d(vert_edge_per_image, th=0.1):
    '''Merge vertices that are close to each other in 3D space and are of same types.

    "Type" here is binary: apex vs. any other vertex class. Vertex i is merged
    with every vertex that appears in any radius-``th`` neighbourhood containing i.
    Groups are formed greedily in ascending vertex order, and each merged
    vertex is the mean of its group.

    Args:
        vert_edge_per_image: mapping view -> (2D vertices, 2D connections, (N, 3) 3D vertices).
        th: merge distance threshold, in the units of the 3D vertices.

    Returns:
        (new_vertices (M, 3), new_connections as sorted [a, b] lists without duplicates/self-loops)
    '''
    all_3d_vertices = []
    connections_3d = []
    types = []
    cur_start = 0
    for vertices, connections, vertices_3d in vert_edge_per_image.values():
        types += [int(v['type'] == 'apex') for v in vertices]
        all_3d_vertices.append(vertices_3d)
        connections_3d += [(x + cur_start, y + cur_start) for (x, y) in connections]
        cur_start += len(vertices_3d)
    if not all_3d_vertices:
        return np.empty((0, 3)), []
    all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
    if len(all_3d_vertices) == 0:
        return np.empty((0, 3)), []

    distmat = cdist(all_3d_vertices, all_3d_vertices)
    types = np.array(types).reshape(-1, 1)
    mask_to_merge = (distmat <= th) & (cdist(types, types) == 0)

    # Distinct neighbourhoods (row patterns of the merge mask).
    neighbourhoods = {tuple(row.nonzero()[0].tolist()) for row in mask_to_merge}
    # For every vertex: union of all neighbourhoods it belongs to.
    candidates = defaultdict(set)
    for hood in neighbourhoods:
        for i in hood:
            candidates[i].update(hood)

    merged_groups = []
    assigned = set()
    for k in sorted(candidates):
        if k in assigned:
            continue
        group = sorted(candidates[k])
        merged_groups.append(group)
        assigned.update(group)

    new_vertices = np.array([all_3d_vertices[g].mean(axis=0) for g in merged_groups]).reshape(-1, 3)
    old_idx_to_new = {}
    for new_idx, group in enumerate(merged_groups):
        for old_idx in group:
            old_idx_to_new[old_idx] = new_idx

    new_connections = []
    seen = set()
    for a, b in connections_3d:
        new_con = sorted((old_idx_to_new[a], old_idx_to_new[b]))
        key = tuple(new_con)
        if new_con[0] == new_con[1] or key in seen:
            continue
        seen.add(key)
        new_connections.append(new_con)
    return new_vertices, new_connections


def prune_not_connected(all_3d_vertices, connections_3d):
    '''Prune vertices that are not connected to any other vertex.

    Vertices with identical coordinates are collapsed into one output vertex,
    and every connection is remapped to that vertex.

    Returns:
        (vertices as an array of tuples, de-duplicated list of (a, b) connections)
    '''
    connected = defaultdict(list)
    for c in connections_3d:
        connected[c[0]].append(c)
        connected[c[1]].append(c)

    new_indexes = {}
    new_verts = []
    vert_to_new = {}
    for k in connected:
        vert = tuple(all_3d_vertices[k])
        if vert not in vert_to_new:
            vert_to_new[vert] = len(new_verts)
            new_verts.append(vert)
        # Duplicate coordinates reuse the existing index instead of being left unmapped.
        new_indexes[k] = vert_to_new[vert]

    connected_out = [(new_indexes[vv[0]], new_indexes[vv[1]])
                     for conns in connected.values() for vv in conns]
    connected_out = list(set(connected_out))
    return np.array(new_verts), connected_out


def uv_to_v3d(uv, depth_vert, K, R, t):
    '''Back-project pixels to world-space 3D points.

    ``depth_vert`` is applied along the unit-length viewing ray through each
    pixel, so it is treated as a distance along the ray rather than a camera z.
    The camera convention is x_cam = R @ X_world + t.

    Returns:
        (N, 3) array of world-space points.
    '''
    # Normalize the uv with the camera intrinsics
    xy_local = np.ones((len(uv), 3))
    xy_local[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
    xy_local[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
    rays = xy_local / np.linalg.norm(xy_local, axis=1)[..., None]
    vertices_3d_local = depth_vert[..., None] * rays

    world_to_cam = np.eye(4)
    world_to_cam[:3, :3] = R
    world_to_cam[:3, 3] = t.reshape(-1)
    cam_to_world = np.linalg.inv(world_to_cam)
    vertices_3d = cv2.transform(cv2.convertPointsToHomogeneous(vertices_3d_local), cam_to_world)
    vertices_3d = cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
    return vertices_3d


def delete_one_vert(vertices, vertices_3d, connections, vert_to_del):
    '''Delete the first 3D vertex within 0.01 (per axis) of ``vert_to_del``.

    Edges touching it are dropped, and indices above it are shifted down by one.
    When ``vertices`` (the 2D vertex list) is non-empty, its matching entry is
    also removed and a 3-tuple is returned. Otherwise ``(vertices_3d, connections)``
    is returned, including when nothing matches.
    '''
    has_2d = len(vertices) > 0
    matches = np.flatnonzero(np.all(np.abs(vertices_3d - vert_to_del) < 0.01, axis=1))
    if matches.size == 0:
        if has_2d:
            return vertices, vertices_3d, connections
        return vertices_3d, connections

    idx = int(matches[0])
    if has_2d:
        vertices = np.delete(vertices, idx)
    vertices_3d = np.delete(vertices_3d, idx, axis=0)

    kept = []
    for a, b in connections:
        a, b = int(a), int(b)
        if a == idx or b == idx:
            continue
        kept.append([a - 1 if a > idx else a, b - 1 if b > idx else b])

    if has_2d:
        return vertices, vertices_3d, kept
    return vertices_3d, kept


def prune_far(all_3d_vertices, connections_3d, prune_dist_thr=3000):
    '''Prune vertices that are far away from any other vertices.

    Repeatedly removes the lowest-index vertex whose nearest other vertex is
    farther than ``prune_dist_thr``. Nothing is pruned (or pruning stops) once
    fewer than 3 vertices or no connections remain.
    '''
    while len(all_3d_vertices) >= 3 and len(connections_3d) >= 1:
        distmat = cdist(all_3d_vertices, all_3d_vertices)
        np.fill_diagonal(distmat, np.inf)  # exclude each vertex's distance to itself
        isolated = np.flatnonzero(distmat.min(axis=1) > prune_dist_thr)
        if isolated.size == 0:
            break
        pt_to_del = all_3d_vertices[isolated[0]]
        all_3d_vertices, connections_3d = delete_one_vert([], all_3d_vertices, connections_3d, pt_to_del)
    return all_3d_vertices, connections_3d


def prune_tall_short(all_3d_vertices, connections_3d, lowest_z, prune_tall_thr=1000, prune_short_thr=100):
    '''Prune vertices that have an impractical z.

    Repeatedly removes the lowest-index vertex whose height above ``lowest_z``
    is greater than ``prune_tall_thr`` or less than ``prune_short_thr``. Pruning
    stops once fewer than 3 vertices or no connections remain.
    '''
    while len(all_3d_vertices) >= 3 and len(connections_3d) >= 1:
        heights = np.asarray(all_3d_vertices)[:, 2] - lowest_z
        bad = np.flatnonzero((heights > prune_tall_thr) | (heights < prune_short_thr))
        if bad.size == 0:
            break
        pt_to_del = all_3d_vertices[bad[0]]
        all_3d_vertices, connections_3d = delete_one_vert([], all_3d_vertices, connections_3d, pt_to_del)
    return all_3d_vertices, connections_3d


def clean_gest(gest_seg_np):
    '''Remove all blobs that are not connected to the largest blob.

    Foreground is every pixel not within +-10 of the 'unclassified' colour.
    Every foreground pixel outside the largest 8-connected foreground component
    is repainted with the background colour (in place). The image is returned
    unchanged when it is entirely background or entirely foreground.
    '''
    bg_color = np.array(gestalt_color_mapping['unclassified'])
    bg_mask = cv2.inRange(gest_seg_np, bg_color - 10, bg_color + 10)
    # inRange yields 0/255 masks, so count pixels rather than summing values.
    n_bg = np.count_nonzero(bg_mask)
    if n_bg == 0 or n_bg == gest_seg_np.shape[0] * gest_seg_np.shape[1]:
        return gest_seg_np

    fg_mask = cv2.bitwise_not(bg_mask)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(fg_mask, 8, cv2.CV_32S)
    sizes = stats[1:, cv2.CC_STAT_AREA]  # skip label 0 (background)
    if sizes.size == 0:
        return gest_seg_np
    max_label = int(np.argmax(sizes)) + 1  # a single label even when areas tie
    # mask out anything that doesn't belong to the largest component
    gest_seg_np[labels != max_label] = bg_color
    return gest_seg_np


def clean_PCD(XYZ, rgb):
    '''Remove all points that do not belong to the largest cluster.

    Clusters come from OPTICS (min_samples=20, max_eps=150, DBSCAN extraction).
    The cloud is returned unchanged when it has fewer than 20 or more than
    130000 points (memory guard), or when clustering yields 1 or more than 40 labels.
    NOTE(review): the noise label (-1) is eligible as the "largest cluster";
    confirm that is intended.

    Returns:
        (XYZ, rgb, lowest_z); lowest_z is always 0.
    '''
    lowest_z = 0
    # avoid memory issue
    if len(XYZ) > 130000 or len(XYZ) < 20:
        return XYZ, rgb, lowest_z

    clust = OPTICS(min_samples=20, max_eps=150, metric='euclidean', cluster_method='dbscan',
                   algorithm='kd_tree').fit(XYZ)
    labels = clust.labels_
    unique_labels, counts = np.unique(labels, return_counts=True)
    if len(unique_labels) > 40 or len(unique_labels) == 1:
        return XYZ, rgb, lowest_z

    largest_blob = unique_labels[np.argmax(counts)]
    retain_class_mask = labels == largest_blob
    return XYZ[retain_class_mask], rgb[retain_class_mask], lowest_z


def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5,
            ) -> Tuple[Any, np.ndarray, List]:
    '''Predict a 3D wireframe for one dataset entry.

    Pipeline:
      (1) detect 2D vertices in each view's gestalt segmentation;
      (2) lift them to 3D with SfM and monocular (depthcm / depth_scale) depth;
      (3) merge them across views (same-type vertices within 150 units).

    Returns:
        (entry key, (N, 3) vertices, edge list). The edge list is always empty.
        ``empty_solution()`` is used when fewer than 2 vertices remain.
    '''
    good_entry = convert_entry_to_human_readable(entry)
    points3D = good_entry['points3d']
    XYZ = np.array([p.xyz for p in points3D.values()]).reshape(-1, 3)
    rgb = np.array([p.rgb for p in points3D.values()]).reshape(-1, 3)
    XYZ, rgb, _ = clean_PCD(XYZ, rgb)
    del points3D

    vert_edge_per_image = {}
    views = zip(good_entry['ade20k'], good_entry['gestalt'], good_entry['depthcm'],
                good_entry['K'], good_entry['R'], good_entry['t'])
    for i, (ade, gest, depth, K, R, t) in enumerate(views):
        # (1) 2D processing
        ade_seg_np = np.array(ade.resize(depth.size)).astype(np.uint8)
        gest_seg_np = np.array(gest.resize(depth.size)).astype(np.uint8)
        gest_seg_np = clean_gest(gest_seg_np)
        # Metric3D
        depth_np = np.array(depth) / depth_scale
        vertices, connections = get_vertices_and_edges_from_two_segmentations(
            ade_seg_np, gest_seg_np, edge_th=50.)
        if len(vertices) < 1:
            vert_edge_per_image[i] = np.empty((0, 2)), [], np.empty((0, 3))
            continue

        # (2) Use depth
        sfm_depth_np = get_SfM_depth(XYZ, rgb, depth_np, gest_seg_np, K, R, t, 5)
        uv, depth_vert = get_smooth_uv_depth(vertices, depth_np, gest_seg_np, sfm_depth_np, 75)
        vertices_3d = uv_to_v3d(uv, depth_vert, K, R, t)
        vert_edge_per_image[i] = vertices, connections, vertices_3d

    # (3) aggregate info collected from all views
    all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 150)
    if len(all_3d_vertices) > 10:
        # NOTE(review): prune_far skips pruning when there are no connections, and
        # connections are currently always empty, so this call returns its input.
        all_3d_vertices_clean, _ = prune_far(all_3d_vertices, connections_3d, prune_dist_thr=prune_dist_thr)
    else:
        all_3d_vertices_clean = all_3d_vertices
    # This solution does not predict edges.
    connections_3d_clean = []

    if len(all_3d_vertices_clean) < 2:
        print('Not enough vertices or connections in the 3D vertices')
        return (good_entry['__key__'], *empty_solution())

    if visualize:
        print(f"num of est: {len(all_3d_vertices_clean)}, num of gt:{len(good_entry['wf_vertices'])}")
        from hoho.viz3d import plot_estimate_and_gt
        plot_estimate_and_gt(all_3d_vertices_clean, connections_3d_clean,
                             good_entry['wf_vertices'], good_entry['wf_edges'])
    return good_entry['__key__'], all_3d_vertices_clean, connections_3d_clean