veb-101 committed on
Commit ca278b3 · 1 Parent(s): a313617

Gradio app fixed

Files changed (3)
  1. Segformer_best_state_dict.ckpt +0 -3
  2. app.py +11 -26
  3. requirements.txt +1 -2
Segformer_best_state_dict.ckpt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:800bb5ba3fff6c5539542cc6d9548da73dbc1a35c0dc686f0bade3b3c6c5746c
-size 256373829
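The deleted file was only a Git LFS pointer; the ~256 MB state-dict checkpoint itself lived in LFS storage. With it gone, app.py points MODEL_PATH at a local segformer_trained_weights directory instead (see the app.py diff below). A minimal sketch of how such a directory is typically produced with the transformers API follows; the training/export script is not part of this commit, so everything here is an assumption rather than the author's code.

```python
from transformers import SegformerForSemanticSegmentation

# Hypothetical export step (not in this commit): after fine-tuning, save the
# model in transformers' standard directory format so the app can load it with
# from_pretrained() instead of calling torch.load() on a raw .ckpt state dict.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b4-finetuned-ade-512-512",
    num_labels=3,                  # the real app would pass Configs.NUM_CLASSES
    ignore_mismatched_sizes=True,  # decode head is re-initialised for the new classes
)
# ... fine-tuning would happen here ...
model.save_pretrained("segformer_trained_weights")  # writes config.json + model weights
```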
app.py CHANGED
@@ -18,7 +18,7 @@ class Configs:
     IMAGE_SIZE: tuple[int, int] = (288, 288)  # W, H
     MEAN: tuple = (0.485, 0.456, 0.406)
     STD: tuple = (0.229, 0.224, 0.225)
-    MODEL_PATH: str = "nvidia/segformer-b4-finetuned-ade-512-512"  # os.path.join(os.getcwd(), "segformer_trained_weights")
+    MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
 
 
 def get_model(*, model_path, num_classes):
@@ -47,11 +47,8 @@ if __name__ == "__main__":
     class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
 
     DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
-    CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
 
     model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
-    model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
-
     model.to(DEVICE)
     model.eval()
     _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
@@ -64,29 +61,17 @@ if __name__ == "__main__":
         ]
     )
 
-    images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
-    examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
-    demo = gr.Interface(
-        fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
-        inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
-        outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
-        examples=examples,
-        cache_examples=False,
-        allow_flagging="never",
-        title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
-    )
-
-    # with gr.Blocks(title="Medical Image Segmentation") as demo:
-    #     gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
-    #     with gr.Row():
-    #         img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
-    #         img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
+    with gr.Blocks(title="Medical Image Segmentation") as demo:
+        gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
+        with gr.Row():
+            img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
+            img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
 
-    #     section_btn = gr.Button("Generate Predictions")
-    #     section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
+        section_btn = gr.Button("Generate Predictions")
+        section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
 
-    #     images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
-    #     examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
-    #     gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
+        images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
+        examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
+        gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
 
     demo.launch()
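For context on what the new Blocks UI wires together: section_btn.click feeds each uploaded image through predict, whose body (like get_model's) is not shown in this diff. The sketch below is a guess at their shape, assuming get_model wraps SegformerForSemanticSegmentation.from_pretrained and that predict returns the (image, [(mask, label), ...]) pair that gr.AnnotatedImage accepts; Configs is the dataclass from app.py, and the function names, 1/2/3 class-id mapping, and resizing choice are assumptions, not code from this commit.

```python
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation


def get_model(*, model_path, num_classes):
    # Assumed: works for both a hub id and the local "segformer_trained_weights" directory.
    return SegformerForSemanticSegmentation.from_pretrained(
        model_path, num_labels=num_classes, ignore_mismatched_sizes=True
    )


@torch.inference_mode()
def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
    # Resize so the returned base image and the predicted masks share one size.
    input_image = input_image.resize(Configs.IMAGE_SIZE)                # PIL expects (W, H)
    batch = preprocess_fn(input_image).unsqueeze(0).to(device)         # (1, 3, H, W)
    logits = model(batch).logits                                       # (1, C, H/4, W/4) for SegFormer
    logits = F.interpolate(logits, size=Configs.IMAGE_SIZE[::-1],
                           mode="bilinear", align_corners=False)
    preds = logits.argmax(dim=1).squeeze(0).cpu().numpy()              # (H, W) class-id map
    id2label = {1: "Stomach", 2: "Small bowel", 3: "Large bowel"}      # assumed; 0 = background
    annotations = [(preds == idx, name) for idx, name in id2label.items()]
    return input_image, annotations                                    # format gr.AnnotatedImage accepts
```

If the real predict differs (for example, it annotates at the original image resolution), the Blocks wiring in the diff is unaffected; only the return format matters to gr.AnnotatedImage.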
requirements.txt CHANGED
@@ -1,5 +1,4 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 torch==2.0.0+cpu
 torchvision==0.15.0
-transformers==4.30.2
-gradio
+transformers==4.30.2