EUNSEO56 commited on
Commit
acc161d
1 Parent(s): cf78818

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -95,7 +95,7 @@ def draw_plot(pred_img, seg):
95
  return fig
96
 
97
 
98
- def sepia(input_img, title):
99
  input_img = Image.fromarray(input_img)
100
 
101
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -108,22 +108,19 @@ def sepia(input_img, title):
108
  ) # We reverse the shape of `image` because `image.size` returns width and height.
109
  seg = tf.math.argmax(logits, axis=-1)[0]
110
 
111
- unique_labels = np.unique(seg.numpy())
112
- label_images = {}
 
 
 
113
 
114
- for label in unique_labels:
115
- mask = (seg.numpy() == label)
116
- color_mask = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
117
- color_mask[mask, :] = colormap[label]
118
- pred_img = np.array(input_img) * 0.5 + color_mask * 0.5
119
- pred_img = pred_img.astype(np.uint8)
120
- fig = draw_plot(pred_img, seg)
121
- label_images[str(label)] = fig
122
 
123
- # Add title to the figure
124
- plt.suptitle(title, fontsize=20)
125
 
126
- return label_images
127
 
128
  demo = gr.Interface(fn=sepia,
129
  inputs=[gr.Image(shape=(400, 600), label="Input Image"), "text"],
 
95
  return fig
96
 
97
 
98
+ def sepia(input_img):
99
  input_img = Image.fromarray(input_img)
100
 
101
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
108
  ) # We reverse the shape of `image` because `image.size` returns width and height.
109
  seg = tf.math.argmax(logits, axis=-1)[0]
110
 
111
+ color_seg = np.zeros(
112
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
113
+ ) # height, width, 3
114
+ for label, color in enumerate(colormap):
115
+ color_seg[seg.numpy() == label, :] = color
116
 
117
+ # Show image + mask
118
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
119
+ pred_img = pred_img.astype(np.uint8)
 
 
 
 
 
120
 
121
+ fig = draw_plot(pred_img, seg)
122
+ return fig
123
 
 
124
 
125
  demo = gr.Interface(fn=sepia,
126
  inputs=[gr.Image(shape=(400, 600), label="Input Image"), "text"],