Vishakaraj committed on
Commit
a4443af
1 Parent(s): 5bdcaaf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -52,6 +52,11 @@ DATASET_COLORMAPS = {
52
  "ade20k": colormaps.ADE20K_COLORMAP,
53
  "voc2012": colormaps.VOC2012_COLORMAP,
54
  }
 
 
 
 
 
55
 
56
  model = init_segmentor(cfg)
57
  load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
@@ -91,11 +96,11 @@ def create_segmenter(cfg, backbone_model):
91
 
92
 
93
  def render_segmentation(segmentation_logits, dataset):
94
- colormap = DATASET_COLORMAPS[dataset]
95
  colormap_array = np.array(colormap, dtype=np.uint8)
96
  segmentation_logits += 1
97
- segmentation_values = colormap_array[segmentation_logits]
98
- segmentation_values = segmentation_values[:, :, ::-1]
 
99
  unique_labels = np.unique(segmentation_logits)
100
 
101
  colormap_array = colormap_array[unique_labels]
@@ -107,7 +112,7 @@ def render_segmentation(segmentation_logits, dataset):
107
  for idx, color in enumerate(colormap_array):
108
  color_box = np.zeros((20, 20, 3), dtype=np.uint8)
109
  color_box[:, :] = color
110
- # color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
111
  _, img_data = cv2.imencode(".jpg", color_box)
112
  img_base64 = base64.b64encode(img_data).decode("utf-8")
113
  img_data_uri = f"data:image/jpg;base64,{img_base64}"
@@ -115,14 +120,15 @@ def render_segmentation(segmentation_logits, dataset):
115
 
116
  html_output += "</div>"
117
 
118
- return Image.fromarray(segmentation_values), html_output
119
 
120
 
121
  def predict(image_file):
122
  array = np.array(image_file)[:, :, ::-1] # BGR
123
  segmentation_logits = inference_segmentor(model, array)[0]
 
124
  segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
125
- return np.array(segmented_image), html_output
126
 
127
  description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
128
 
@@ -130,10 +136,10 @@ demo = gr.Interface(
130
  title="Semantic Segmentation - DinoV2",
131
  fn=predict,
132
  inputs=gr.inputs.Image(),
133
- outputs=[gr.outputs.Image(type="numpy"), gr.outputs.HTML()],
134
  examples=["example_1.jpg", "example_2.jpg"],
135
  cache_examples=False,
136
  description=description,
137
  )
138
 
139
- demo.launch()
 
52
  "ade20k": colormaps.ADE20K_COLORMAP,
53
  "voc2012": colormaps.VOC2012_COLORMAP,
54
  }
55
+ colormap = DATASET_COLORMAPS["ade20k"]
56
+ flattened = np.array(colormap).flatten()
57
+ zeros = np.zeros(768)
58
+ zeros[:flattened.shape[0]] = flattened
59
+ colorMap = list(zeros.astype('uint8'))
60
 
61
  model = init_segmentor(cfg)
62
  load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
 
96
 
97
 
98
  def render_segmentation(segmentation_logits, dataset):
 
99
  colormap_array = np.array(colormap, dtype=np.uint8)
100
  segmentation_logits += 1
101
+ segmented_image = Image.fromarray(segmentation_logits)
102
+ segmented_image.putpalette(colorMap)
103
+
104
  unique_labels = np.unique(segmentation_logits)
105
 
106
  colormap_array = colormap_array[unique_labels]
 
112
  for idx, color in enumerate(colormap_array):
113
  color_box = np.zeros((20, 20, 3), dtype=np.uint8)
114
  color_box[:, :] = color
115
+ color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
116
  _, img_data = cv2.imencode(".jpg", color_box)
117
  img_base64 = base64.b64encode(img_data).decode("utf-8")
118
  img_data_uri = f"data:image/jpg;base64,{img_base64}"
 
120
 
121
  html_output += "</div>"
122
 
123
+ return segmented_image, html_output
124
 
125
 
126
  def predict(image_file):
127
  array = np.array(image_file)[:, :, ::-1] # BGR
128
  segmentation_logits = inference_segmentor(model, array)[0]
129
+ segmentation_logits = segmentation_logits.astype(np.uint8)
130
  segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
131
+ return segmented_image, html_output
132
 
133
  description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"
134
 
 
136
  title="Semantic Segmentation - DinoV2",
137
  fn=predict,
138
  inputs=gr.inputs.Image(),
139
+ outputs=[gr.outputs.Image(type="pil"), gr.outputs.HTML()],
140
  examples=["example_1.jpg", "example_2.jpg"],
141
  cache_examples=False,
142
  description=description,
143
  )
144
 
145
+ demo.launch()