ronnief1 commited on
Commit
ffbd8d1
1 Parent(s): cfa0071

Update app2.py

Browse files
Files changed (1) hide show
  1. app2.py +100 -21
app2.py CHANGED
@@ -226,6 +226,63 @@ def model_infer(img_name):
226
  break
227
 
228
  return image_vis, gt_mask, pr_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  PAGE_TITLE = "Polyp Segmentation"
230
 
231
  def file_selector(folder_path='.'):
@@ -242,33 +299,55 @@ def file_selector_ui():
242
  return filename
243
 
244
  def file_upload(folder_path='.'):
245
- filenames = os.listdir(folder_path)
246
  folder_path = './test/test/images'
247
- uploaded_file = st.file_uploader("Choose a file")
248
- filename = os.path.join(folder_path, uploaded_file.name)
249
- printname = list(filename)
250
- printname[filename.rfind('\\')] = '/'
251
- st.write('You selected`%s`' % ''.join(printname))
252
- return filename
 
 
253
 
254
 
255
  def main():
256
  st.set_page_config(page_title=PAGE_TITLE, layout="wide")
257
  st.title(PAGE_TITLE)
258
- image_path = file_selector_ui()
259
- # image_path = file_upload()
260
- image_path = os.path.abspath(image_path)
261
- to_infer = image_path[image_path.rfind("\\") + 1:]
262
-
263
- if os.path.isfile(image_path) is True:
264
- _, file_extension = os.path.splitext(image_path)
265
- if file_extension == ".jpg":
266
- image_vis, gt_mask, pr_mask = model_infer(to_infer)
267
- visualize(
268
- image=image_vis,
269
- #ground_truth_mask=gt_mask,
270
- predicted_mask=pr_mask
271
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  if __name__ == "__main__":
274
  main()
 
226
  break
227
 
228
  return image_vis, gt_mask, pr_mask
229
+
230
+ def model_infer_new(img_name):
231
+
232
+ model = smp.UnetPlusPlus(
233
+ encoder_name=ENCODER,
234
+ encoder_weights=ENCODER_WEIGHTS,
235
+ encoder_depth=5,
236
+ decoder_channels=(256, 128, 64, 32, 16),
237
+ classes=len(CLASSES),
238
+ activation=ACTIVATION,
239
+ decoder_attention_type=None,
240
+ )
241
+
242
+
243
+ model.load_state_dict(torch.load('best.pth', map_location=torch.device('cpu'))['model_state_dict'])
244
+ model.eval()
245
+
246
+ test_dataset = Dataset(
247
+ img_name,
248
+ img_name,
249
+ augmentation=get_validation_augmentation(),
250
+ preprocessing=get_preprocessing(preprocessing_fn),
251
+ classes=CLASSES,
252
+ single_file=True
253
+ )
254
+
255
+ test_dataloader = DataLoader(test_dataset)
256
+
257
+ loaders = {"infer": test_dataloader}
258
+
259
+ runner = SupervisedRunner()
260
+ logits = []
261
+ f = 0
262
+ for prediction in runner.predict_loader(model=model, loader=loaders['infer'],cpu=True):
263
+ if f < 3:
264
+ logits.append(prediction['logits'])
265
+ f = f + 1
266
+ else:
267
+ break
268
+
269
+ threshold = 0.5
270
+ break_at = 1
271
+
272
+ for i, (input, output) in enumerate(zip(
273
+ test_dataset, logits)):
274
+ image, mask = input
275
+
276
+ image_vis = image.transpose(1, 2, 0)
277
+ pr_mask = (output[0].numpy() > threshold).astype('uint8')[0]
278
+ i = i + 1
279
+ if i >= break_at:
280
+ break
281
+
282
+ return image_vis, pr_mask
283
+
284
+
285
+
286
  PAGE_TITLE = "Polyp Segmentation"
287
 
288
  def file_selector(folder_path='.'):
 
299
  return filename
300
 
301
  def file_upload(folder_path='.'):
302
+ # filenames = os.listdir(folder_path)
303
  folder_path = './test/test/images'
304
+ uploaded_file = st.file_uploader("Choose a file")
305
+ if uploaded_file is not None:
306
+ filename = os.path.join(folder_path, uploaded_file.name)
307
+ printname = list(filename)
308
+ printname[filename.rfind('\\')] = '/'
309
+ st.write('You selected`%s`' % ''.join(printname))
310
+ return filename
311
+
312
 
313
 
314
  def main():
315
  st.set_page_config(page_title=PAGE_TITLE, layout="wide")
316
  st.title(PAGE_TITLE)
317
+ choice = st.radio(
318
+ "Upload your own image or infer on a pre-existing image?",
319
+ ('Pre-existing', 'Own'))
320
+
321
+
322
+ if choice == 'Pre-existing':
323
+ image_path = file_selector_ui()
324
+ image_path = os.path.abspath(image_path)
325
+ to_infer = image_path[image_path.rfind("\\") + 1:]
326
+
327
+ if os.path.isfile(image_path) is True:
328
+ _, file_extension = os.path.splitext(image_path)
329
+ if file_extension == ".jpg":
330
+ image_vis, gt_mask, pr_mask = model_infer(to_infer)
331
+ visualize(
332
+ image=image_vis,
333
+ ground_truth_mask=gt_mask,
334
+ predicted_mask=pr_mask
335
+ )
336
+
337
+ if choice == 'Own':
338
+ image_path = file_upload()
339
+ if image_path is not None:
340
+ image_path = os.path.abspath(image_path)
341
+ to_infer = image_path[image_path.rfind("\\") + 1:]
342
+
343
+ if os.path.isfile(image_path) is True:
344
+ _, file_extension = os.path.splitext(image_path)
345
+ if file_extension == ".jpg":
346
+ image_vis, pr_mask = model_infer_new(to_infer)
347
+ visualize(
348
+ image=image_vis,
349
+ predicted_mask=pr_mask
350
+ )
351
 
352
  if __name__ == "__main__":
353
  main()