kcml commited on
Commit
f111e99
·
1 Parent(s): 1ac89a4

clean gest

Browse files
Files changed (1) hide show
  1. handcrafted_solution.py +29 -11
handcrafted_solution.py CHANGED
@@ -251,7 +251,7 @@ def get_smooth_uv_depth(vertices, depth, gest_seg_np, sfm_depth_np, r=5):
251
  vertex_depth = np.array(vertex_depth)
252
  return uv, vertex_depth
253
 
254
- '''
255
  from numba import njit, prange
256
  @njit(parallel=True)
257
  def fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W):
@@ -265,7 +265,7 @@ def fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W):
265
  if DUMP_IMG:
266
  sfm_color_np[j, i] = c
267
  return sfm_depth_np, sfm_color_np
268
- '''
269
 
270
  def get_SfM_depth(points3D, depth_np, gest_seg_np, K, R, t, dilate_r = 5):
271
  '''Project 3D sfm pointcloud to the image plane '''
@@ -294,7 +294,7 @@ def get_SfM_depth(points3D, depth_np, gest_seg_np, K, R, t, dilate_r = 5):
294
  #checked = 0
295
  #print('dim of us uv zs rgb:', len(us), len(vs), len(zs), len(rgb))
296
  for u,v,z,c in zip(us,vs,zs, rgb):
297
- '''
298
  sfm_depth_np, sfm_color_np = fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W)
299
  '''
300
  i_range = range(max(0, u - dilate_r), min(W, u + dilate_r))
@@ -308,7 +308,7 @@ def get_SfM_depth(points3D, depth_np, gest_seg_np, K, R, t, dilate_r = 5):
308
  sfm_depth_np[j, i] = z
309
  if DUMP_IMG:
310
  sfm_color_np[j, i] = c
311
-
312
 
313
 
314
  #print(f'checked {checked} pts')
@@ -729,14 +729,14 @@ def prune_far(all_3d_vertices, connections_3d, prune_dist_thr=3000):
729
 
730
  return all_3d_vertices, connections_3d
731
 
732
- def prune_tall(all_3d_vertices, connections_3d, prune_tall_thr=1000):
733
  '''Prune vertices that has inpractical z'''
734
  if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
735
  return all_3d_vertices, connections_3d
736
 
737
  isolated = []
738
  for i,v in enumerate(all_3d_vertices):
739
- if v[2] > prune_tall_thr:
740
  isolated.append(i)
741
  break
742
 
@@ -749,12 +749,26 @@ def prune_tall(all_3d_vertices, connections_3d, prune_tall_thr=1000):
749
  return all_3d_vertices, connections_3d
750
 
751
  for i,v in enumerate(all_3d_vertices):
752
- if v[2] > prune_tall_thr:
753
  isolated.append(i)
754
  break
755
 
756
  return all_3d_vertices, connections_3d
757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
 
759
  def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5, ) -> Tuple[np.ndarray, List[int]]:
760
  good_entry = convert_entry_to_human_readable(entry)
@@ -775,13 +789,16 @@ def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5, ) -> Tu
775
  elif i==2: # only visualize view 0,1
776
  continue
777
 
778
- if i!=0:
779
  continue
780
  '''
781
  ade_seg = ade.resize(depth.size)
782
  ade_seg_np = np.array(ade_seg).astype(np.uint8)
783
  gest_seg = gest.resize(depth.size)
784
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
 
 
 
785
  # Metric3D
786
  depth_np = np.array(depth) / depth_scale # / 2.5 # 2.5 is the scale estimation coefficient # don't use 2.5...
787
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 20.)
@@ -818,7 +835,8 @@ def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5, ) -> Tu
818
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 150)
819
  #print(f'after merge, {len(all_3d_vertices)} 3d vertices and {len(connections_3d)} 3d connections')
820
  #all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
821
- all_3d_vertices, connections_3d = prune_tall(all_3d_vertices, connections_3d, 1000)
 
822
 
823
  if len(all_3d_vertices)>35:
824
  all_3d_vertices, connections_3d = prune_not_connected(all_3d_vertices, connections_3d)
@@ -826,13 +844,13 @@ def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5, ) -> Tu
826
  all_3d_vertices_clean, connections_3d_clean = prune_far(all_3d_vertices, connections_3d, prune_dist_thr=prune_dist_thr)
827
  else:
828
  all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
829
-
830
  if i%2:
831
  connections_3d_clean = [connections_3d_clean[-1], connections_3d_clean[0]]
832
  else:
833
  connections_3d_clean = [connections_3d_clean[0]]
834
  #print('connections_3d_clean=', connections_3d_clean)
835
-
836
  #all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d # don't prune -> cost:2.0
837
  #print(f'after pruning, {len(all_3d_vertices_clean)} 3d clean vertices and {len(connections_3d_clean)} 3d clean connections')
838
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1:
 
251
  vertex_depth = np.array(vertex_depth)
252
  return uv, vertex_depth
253
 
254
+
255
  from numba import njit, prange
256
  @njit(parallel=True)
257
  def fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W):
 
265
  if DUMP_IMG:
266
  sfm_color_np[j, i] = c
267
  return sfm_depth_np, sfm_color_np
268
+ ''''''
269
 
270
  def get_SfM_depth(points3D, depth_np, gest_seg_np, K, R, t, dilate_r = 5):
271
  '''Project 3D sfm pointcloud to the image plane '''
 
294
  #checked = 0
295
  #print('dim of us uv zs rgb:', len(us), len(vs), len(zs), len(rgb))
296
  for u,v,z,c in zip(us,vs,zs, rgb):
297
+
298
  sfm_depth_np, sfm_color_np = fill_range(u, v, z, dilate_r, c, sfm_depth_np, sfm_color_np, H, W)
299
  '''
300
  i_range = range(max(0, u - dilate_r), min(W, u + dilate_r))
 
308
  sfm_depth_np[j, i] = z
309
  if DUMP_IMG:
310
  sfm_color_np[j, i] = c
311
+ '''
312
 
313
 
314
  #print(f'checked {checked} pts')
 
729
 
730
  return all_3d_vertices, connections_3d
731
 
732
+ def prune_tall_short(all_3d_vertices, connections_3d, prune_tall_thr=1000, prune_short_thr=100):
733
  '''Prune vertices that has inpractical z'''
734
  if (len(all_3d_vertices) < 3) or len(connections_3d) < 1:
735
  return all_3d_vertices, connections_3d
736
 
737
  isolated = []
738
  for i,v in enumerate(all_3d_vertices):
739
+ if v[2] > prune_tall_thr or v[2] < prune_short_thr:
740
  isolated.append(i)
741
  break
742
 
 
749
  return all_3d_vertices, connections_3d
750
 
751
  for i,v in enumerate(all_3d_vertices):
752
+ if v[2] > prune_tall_thr or v[2] < prune_short_thr:
753
  isolated.append(i)
754
  break
755
 
756
  return all_3d_vertices, connections_3d
757
 
758
+ def clean_gest(gest_seg_np):
759
+ bg_color = np.array(gestalt_color_mapping['unclassified'])
760
+ bg_mask = cv2.inRange(gest_seg_np, bg_color-10, bg_color+10)
761
+ fg_mask = cv2.bitwise_not(bg_mask)
762
+ if fg_mask.sum() > 0:
763
+ output = cv2.connectedComponentsWithStats(fg_mask, 8, cv2.CV_32S)
764
+ (numLabels, labels, stats, centroids) = output
765
+ sizes = stats[1:, -1] # Get the areas (skip the first entry which is the background)
766
+ max_area = max(sizes)
767
+ max_label = np.where(sizes == max_area)[0] + 1 # Add 1 to get the actual label
768
+ # mask out anything that doesn't belong to the largest component
769
+ gest_seg_np[labels != max_label] = bg_color
770
+
771
+ return gest_seg_np
772
 
773
  def predict(entry, visualize=False, prune_dist_thr=600, depth_scale=2.5, ) -> Tuple[np.ndarray, List[int]]:
774
  good_entry = convert_entry_to_human_readable(entry)
 
789
  elif i==2: # only visualize view 0,1
790
  continue
791
 
792
+ if i!=3:
793
  continue
794
  '''
795
  ade_seg = ade.resize(depth.size)
796
  ade_seg_np = np.array(ade_seg).astype(np.uint8)
797
  gest_seg = gest.resize(depth.size)
798
  gest_seg_np = np.array(gest_seg).astype(np.uint8)
799
+ gest_seg_np = clean_gest(gest_seg_np)
800
+ #print('gest_seg_np=', gest_seg_np.shape)
801
+
802
  # Metric3D
803
  depth_np = np.array(depth) / depth_scale # / 2.5 # 2.5 is the scale estimation coefficient # don't use 2.5...
804
  #vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np, edge_th = 20.)
 
835
  all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image, 150)
836
  #print(f'after merge, {len(all_3d_vertices)} 3d vertices and {len(connections_3d)} 3d connections')
837
  #all_3d_vertices_clean, connections_3d_clean = prune_not_connected(all_3d_vertices, connections_3d)
838
+ all_3d_vertices, connections_3d = prune_tall_short(all_3d_vertices, connections_3d, 1000, 0)
839
+
840
 
841
  if len(all_3d_vertices)>35:
842
  all_3d_vertices, connections_3d = prune_not_connected(all_3d_vertices, connections_3d)
 
844
  all_3d_vertices_clean, connections_3d_clean = prune_far(all_3d_vertices, connections_3d, prune_dist_thr=prune_dist_thr)
845
  else:
846
  all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d
847
+
848
  if i%2:
849
  connections_3d_clean = [connections_3d_clean[-1], connections_3d_clean[0]]
850
  else:
851
  connections_3d_clean = [connections_3d_clean[0]]
852
  #print('connections_3d_clean=', connections_3d_clean)
853
+ ''''''
854
  #all_3d_vertices_clean, connections_3d_clean = all_3d_vertices, connections_3d # don't prune -> cost:2.0
855
  #print(f'after pruning, {len(all_3d_vertices_clean)} 3d clean vertices and {len(connections_3d_clean)} 3d clean connections')
856
  if (len(all_3d_vertices_clean) < 2) or len(connections_3d_clean) < 1: