niulx commited on
Commit
62c1dec
1 Parent(s): c1234b6

Update segment.py

Browse files
Files changed (1) hide show
  1. segment.py +8 -0
segment.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  import argparse
12
  import matplotlib
13
  import gradio as gr
 
14
 
15
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
16
  if type(image_path) is str:
@@ -52,6 +53,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
52
  if torch.min(segmentation) == 0:
53
  mask = segmentation==0
54
  mask = mask.cpu().detach().numpy() # [512,512] bool
 
 
55
  segment_label = "rest"
56
  color = viridis(0)
57
  label = f"{segment_label}-{0}"
@@ -65,6 +68,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
65
  if torch.min(segmentation) != 0:
66
  segment_id -= 1
67
  mask = mask.cpu().detach().numpy() # [512,512] bool
 
 
68
  mask_np_list.append(mask)
69
  segment_label = model.config.id2label[segment['label_id']]
70
  instances_counter[segment['label_id']] += 1
@@ -76,6 +81,9 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
76
  label_list.append(label)
77
  else:
78
  mask = np.full(segmentation.shape, True)
 
 
 
79
  segment_label = "all"
80
  mask_np_list.append(mask)
81
  color = viridis(0)
 
11
  import argparse
12
  import matplotlib
13
  import gradio as gr
14
+ import cv2
15
 
16
  def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
17
  if type(image_path) is str:
 
53
  if torch.min(segmentation) == 0:
54
  mask = segmentation==0
55
  mask = mask.cpu().detach().numpy() # [512,512] bool
56
+ print(mask.shape)
57
+ mask = cv2.resize(mask,(512,512))
58
  segment_label = "rest"
59
  color = viridis(0)
60
  label = f"{segment_label}-{0}"
 
68
  if torch.min(segmentation) != 0:
69
  segment_id -= 1
70
  mask = mask.cpu().detach().numpy() # [512,512] bool
71
+ print(mask.shape)
72
+ mask = cv2.resize(mask,(512,512))
73
  mask_np_list.append(mask)
74
  segment_label = model.config.id2label[segment['label_id']]
75
  instances_counter[segment['label_id']] += 1
 
81
  label_list.append(label)
82
  else:
83
  mask = np.full(segmentation.shape, True)
84
+ print(mask.shape)
85
+ mask = cv2.resize(mask,(512,512))
86
+
87
  segment_label = "all"
88
  mask_np_list.append(mask)
89
  color = viridis(0)