EUNSEO56 commited on
Commit
c9ca8aa
·
1 Parent(s): cb1c20a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -95,7 +95,7 @@ def draw_plot(pred_img, seg):
95
  return fig
96
 
97
 
98
- def sepia(input_img):
99
  input_img = Image.fromarray(input_img)
100
 
101
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -108,24 +108,27 @@ def sepia(input_img):
108
  ) # We reverse the shape of `image` because `image.size` returns width and height.
109
  seg = tf.math.argmax(logits, axis=-1)[0]
110
 
111
- color_seg = np.zeros(
112
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
113
- ) # height, width, 3
114
- for label, color in enumerate(colormap):
115
- color_seg[seg.numpy() == label, :] = color
116
 
117
- # Show image + mask
118
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
119
- pred_img = pred_img.astype(np.uint8)
 
 
 
 
 
120
 
121
- fig = draw_plot(pred_img, seg)
122
- return fig
 
 
123
 
124
  demo = gr.Interface(fn=sepia,
125
- inputs=gr.Image(shape=(400, 600)),
126
- outputs=['plot'],
127
- examples=[["side-1.jpg","Image-1"], ["side-2.jpg","Image-2"], ["side-3.jpg","Image-3"], ["side-4.jpg","Image-4"], ["side-5.jpg","Image-5"], ["side-6.jpg","Image-6"]],
128
  allow_flagging='never')
129
 
130
-
131
  demo.launch()
 
95
  return fig
96
 
97
 
98
+ def sepia(input_img, title):
99
  input_img = Image.fromarray(input_img)
100
 
101
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
108
  ) # We reverse the shape of `image` because `image.size` returns width and height.
109
  seg = tf.math.argmax(logits, axis=-1)[0]
110
 
111
+ unique_labels = np.unique(seg.numpy())
112
+ label_images = {}
 
 
 
113
 
114
+ for label in unique_labels:
115
+ mask = (seg.numpy() == label)
116
+ color_mask = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
117
+ color_mask[mask, :] = colormap[label]
118
+ pred_img = np.array(input_img) * 0.5 + color_mask * 0.5
119
+ pred_img = pred_img.astype(np.uint8)
120
+ fig = draw_plot(pred_img, seg)
121
+ label_images[str(label)] = fig
122
 
123
+ # Add title to the figure
124
+ plt.suptitle(title, fontsize=20)
125
+
126
+ return label_images
127
 
128
  demo = gr.Interface(fn=sepia,
129
+ inputs=[gr.Image(shape=(400, 600), label="Input Image"), "text"],
130
+ outputs='image',
131
+ examples=[["side-1.jpg", "Title 1"], ["side-2.jpg", "Title 2"], ["side-3.jpg", "Title 3"]],
132
  allow_flagging='never')
133
 
 
134
  demo.launch()