Gradio app fixed
Browse files- Segformer_best_state_dict.ckpt +0 -3
- app.py +11 -26
- requirements.txt +1 -2
Segformer_best_state_dict.ckpt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:800bb5ba3fff6c5539542cc6d9548da73dbc1a35c0dc686f0bade3b3c6c5746c
|
3 |
-
size 256373829
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -18,7 +18,7 @@ class Configs:
|
|
18 |
IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
|
19 |
MEAN: tuple = (0.485, 0.456, 0.406)
|
20 |
STD: tuple = (0.229, 0.224, 0.225)
|
21 |
-
MODEL_PATH: str =
|
22 |
|
23 |
|
24 |
def get_model(*, model_path, num_classes):
|
@@ -47,11 +47,8 @@ if __name__ == "__main__":
|
|
47 |
class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
|
48 |
|
49 |
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
50 |
-
CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
|
51 |
|
52 |
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
|
53 |
-
model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
|
54 |
-
|
55 |
model.to(DEVICE)
|
56 |
model.eval()
|
57 |
_ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
|
@@ -64,29 +61,17 @@ if __name__ == "__main__":
|
|
64 |
]
|
65 |
)
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
|
73 |
-
examples=examples,
|
74 |
-
cache_examples=False,
|
75 |
-
allow_flagging="never",
|
76 |
-
title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
|
77 |
-
)
|
78 |
-
|
79 |
-
# with gr.Blocks(title="Medical Image Segmentation") as demo:
|
80 |
-
# gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
|
81 |
-
# with gr.Row():
|
82 |
-
# img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
|
83 |
-
# img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
|
84 |
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
|
92 |
demo.launch()
|
|
|
18 |
IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
|
19 |
MEAN: tuple = (0.485, 0.456, 0.406)
|
20 |
STD: tuple = (0.229, 0.224, 0.225)
|
21 |
+
MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
|
22 |
|
23 |
|
24 |
def get_model(*, model_path, num_classes):
|
|
|
47 |
class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
|
48 |
|
49 |
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
50 |
|
51 |
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
|
|
|
|
|
52 |
model.to(DEVICE)
|
53 |
model.eval()
|
54 |
_ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
|
|
|
61 |
]
|
62 |
)
|
63 |
|
64 |
+
with gr.Blocks(title="Medical Image Segmentation") as demo:
|
65 |
+
gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
|
66 |
+
with gr.Row():
|
67 |
+
img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
|
68 |
+
img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
+
section_btn = gr.Button("Generate Predictions")
|
71 |
+
section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
|
72 |
|
73 |
+
images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
|
74 |
+
examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
|
75 |
+
gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
|
76 |
|
77 |
demo.launch()
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
torch==2.0.0+cpu
|
3 |
torchvision==0.15.0
|
4 |
-
transformers==4.30.2
|
5 |
-
gradio
|
|
|
1 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
torch==2.0.0+cpu
|
3 |
torchvision==0.15.0
|
4 |
+
transformers==4.30.2
|
|