veb-101 commited on
Commit
06ba000
·
1 Parent(s): c438991

check updated app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -11
app.py CHANGED
@@ -28,7 +28,7 @@ def get_model(*, model_path, num_classes):
28
 
29
  @torch.inference_mode()
30
  def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
31
- shape_H_W = input_image.size
32
  input_tensor = preprocess_fn(input_image)
33
  input_tensor = input_tensor.unsqueeze(0).to(device)
34
 
@@ -70,20 +70,32 @@ if __name__ == "__main__":
70
  ]
71
  )
72
 
73
- with gr.Blocks(title="Medical Image Segmentation") as demo:
74
- gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- with gr.Row():
77
- img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
78
- img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
79
 
80
- section_btn = gr.Button("Generate Predictions")
81
 
82
- section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
83
 
84
- images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
85
- examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
86
 
87
- gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
88
 
89
  demo.launch()
 
28
 
29
  @torch.inference_mode()
30
  def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
31
+ shape_H_W = input_image.size[::-1]
32
  input_tensor = preprocess_fn(input_image)
33
  input_tensor = input_tensor.unsqueeze(0).to(device)
34
 
 
70
  ]
71
  )
72
 
73
+ images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
74
+ examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
75
+ demo = gr.Interface(
76
+ fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
77
+ inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
78
+ outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
79
+ examples=examples,
80
+ cache_examples=False,
81
+ allow_flagging="never",
82
+ title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
83
+ )
84
+
85
+ # with gr.Blocks(title="Medical Image Segmentation") as demo:
86
+ # gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
87
 
88
+ # with gr.Row():
89
+ # img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
90
+ # img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
91
 
92
+ # section_btn = gr.Button("Generate Predictions")
93
 
94
+ # section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
95
 
96
+ # images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
97
+ # examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
98
 
99
+ # gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
100
 
101
  demo.launch()