File size: 4,008 Bytes
c438991 b6db8ed c438991 06ba000 c438991 b6db8ed c438991 b6db8ed c438991 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import os
import numpy as np
import gradio as gr
from glob import glob
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torchvision.transforms as TF
from transformers import SegformerForSemanticSegmentation
@dataclass
class Configs:
    """Static settings shared by model construction, preprocessing and inference.

    The values are read directly as class attributes (``Configs.CLASSES`` etc.);
    creating an instance is not required.
    """

    # Channels produced by the segmentation head: background plus one per name in CLASSES.
    NUM_CLASSES: int = 4
    # Display names for class IDs 1, 2, 3 in that order (ID 0 is background).
    CLASSES: tuple[str, ...] = ("Large bowel", "Small bowel", "Stomach")
    # Network input resolution stored as (width, height); reverse it to get (H, W).
    IMAGE_SIZE: tuple[int, int] = (288, 288)
    # Per-channel mean / std used by the Normalize step of preprocessing.
    MEAN: tuple[float, float, float] = (0.485, 0.456, 0.406)
    STD: tuple[float, float, float] = (0.229, 0.224, 0.225)
    # Hugging Face checkpoint that supplies the SegFormer architecture.
    # (Alternatively a local export, e.g. os.path.join(os.getcwd(), "segformer_trained_weights").)
    MODEL_PATH: str = "nvidia/segformer-b4-finetuned-ade-512-512"
def get_model(*, model_path, num_classes):
    """Instantiate a SegFormer segmentation model whose head has ``num_classes`` outputs.

    Args:
        model_path: Hugging Face model id or local directory passed to ``from_pretrained``.
        num_classes: Number of labels for the segmentation head.

    Returns:
        The ``SegformerForSemanticSegmentation`` instance built by ``from_pretrained``.
    """
    return SegformerForSemanticSegmentation.from_pretrained(
        model_path,
        num_labels=num_classes,
        # The pretrained classifier may have a different label count than
        # `num_classes`; this lets loading proceed and re-initialises the
        # mismatched weights instead of raising.
        ignore_mismatched_sizes=True,
    )
@torch.inference_mode()
def predict(input_image, model=None, preprocess_fn=None, device="cpu", class_names=None):
    """Segment an image and return it with one boolean mask per foreground class.

    Args:
        input_image: Image exposing ``.size`` as (width, height) (a PIL image in
            the Gradio app), or ``None`` when no image has been provided.
        model: Called as ``model(pixel_values=..., return_dict=True)``; its output
            must support ``["logits"]`` giving a 4-D (batch, classes, h, w) tensor.
        preprocess_fn: Callable mapping ``input_image`` to a tensor that gains a
            batch dimension via ``unsqueeze(0)`` (presumably (C, h, w) — confirm).
        device: Device the preprocessed tensor is moved to before the forward pass.
        class_names: Names for class IDs 1..N (ID 0 is background and gets no
            mask). Defaults to ``Configs.CLASSES``.

    Returns:
        ``(input_image, [(mask, name), ...])`` — the format ``gr.AnnotatedImage``
        consumes — where each mask is a boolean (H, W) array at the original image
        resolution; or ``None`` when ``input_image`` is ``None``.

    Raises:
        TypeError: If ``model`` or ``preprocess_fn`` is not supplied.
    """
    # Gradio passes None when the button is clicked before an image is uploaded;
    # returning None simply clears the output component.
    if input_image is None:
        return None
    if model is None or preprocess_fn is None:
        raise TypeError("predict() requires both `model` and `preprocess_fn`.")
    if class_names is None:
        class_names = Configs.CLASSES

    # `.size` is (W, H); interpolate wants (H, W).
    out_size = tuple(input_image.size[::-1])
    input_tensor = preprocess_fn(input_image).unsqueeze(0).to(device)

    outputs = model(pixel_values=input_tensor, return_dict=True)
    # Upsample logits to the original resolution before argmax so masks line up with the image.
    logits = F.interpolate(outputs["logits"], size=out_size, mode="bilinear", align_corners=False)
    # Take batch element 0 explicitly: squeeze() would also drop H or W when either equals 1.
    class_map = logits.argmax(dim=1)[0].cpu().numpy()

    seg_info = [(class_map == idx, name) for idx, name in enumerate(class_names, start=1)]
    return input_image, seg_info
if __name__ == "__main__":
    # Class-ID -> RGB colour table (not referenced below).
    # NOTE(review): these comments put Stomach at ID 1 and Large bowel at ID 3,
    # while `predict` labels ID 1 with Configs.CLASSES[0] ("Large bowel") and
    # ID 3 with "Stomach". Confirm which order the checkpoint was trained with.
    id2color = {
        0: (0, 0, 0),  # background pixel
        1: (0, 0, 255),  # Stomach
        2: (0, 255, 0),  # Small bowel
        3: (255, 0, 0),  # large bowel
    }
    # Display colour per class name; keys must match the names predict() emits.
    class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
    # Upper bound on how many sample images are offered as clickable examples.
    MAX_EXAMPLES = 8

    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")

    # Build the architecture with a NUM_CLASSES head, then overwrite its weights
    # with the fine-tuned state dict.
    model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
    model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    # Warm-up forward pass at the configured (H, W); inference_mode avoids
    # recording an autograd graph for this throwaway batch.
    with torch.inference_mode():
        _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))

    preprocess = TF.Compose(
        [
            TF.Resize(size=Configs.IMAGE_SIZE[::-1]),
            TF.ToTensor(),
            TF.Normalize(Configs.MEAN, Configs.STD, inplace=True),
        ]
    )

    # Sample up to MAX_EXAMPLES distinct images; np.random.choice with
    # replace=False raises if asked for more items than exist.
    images_dir = glob(os.path.join(os.getcwd(), "samples", "*.png"))
    num_examples = min(MAX_EXAMPLES, len(images_dir))
    examples = np.random.choice(images_dir, size=num_examples, replace=False).tolist() if images_dir else []

    with gr.Blocks(title="Medical Image Segmentation") as demo:
        gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
        with gr.Row():
            img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
            img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)

        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)

        # Only render the examples gallery when sample images were found.
        if examples:
            gr.Examples(examples=examples, inputs=img_input, outputs=img_output)

    demo.launch()
|