|
import gradio as gr |
|
import os |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
from typing import Tuple, Dict |
|
from timeit import default_timer as timer |
|
from skimage import io, transform |
|
import os |
|
import base64 |
|
import json |
|
|
|
import torch.nn.functional as F |
|
|
|
from model import create_sam_model |
|
|
|
|
|
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model variant and the weights file handed to the project's model factory.
# NOTE(review): the file name looks like the stock SAM ViT-B checkpoint;
# confirm these are the intended (MedSAM) weights.
model_type = "vit_b"
checkpoint = "sam_vit_b_01ec64.pth"

# Built once at import time and shared by every prediction request.
medsam_model = create_sam_model(model_type, checkpoint, device)
|
|
|
|
|
def show_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as a translucent blue RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width) and which
            can be reshaped to exactly that 2-D shape. Each pixel's RGBA
            value is the mask value multiplied by the overlay colour, so
            zero entries stay fully transparent.
        ax: Object providing ``imshow`` (e.g. a Matplotlib axes).
    """
    # Dodger-blue with 60% opacity.
    rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an (H, W, 4) image.
    plane = mask.reshape(height, width)[:, :, np.newaxis]
    overlay = plane * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Points whose label equals 1 are drawn in green first, then points whose
    label equals 0 are drawn in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise against 1 and 0 to select rows
            of ``coords``.
        ax: Object providing ``scatter`` (e.g. a Matplotlib axes).
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
|
|
|
def show_box(box, ax):
    """Outline a bounding box on ``ax`` with an unfilled green rectangle.

    Args:
        box: Indexable as ``(x0, y0, x1, y1)``; the rectangle is anchored at
            ``(x0, y0)`` with width ``x1 - x0`` and height ``y1 - y0``.
        ax: Object providing ``add_patch`` (e.g. a Matplotlib axes).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Decode a binary mask from an image embedding and box prompt(s).

    Args:
        medsam_model: Model exposing a callable ``prompt_encoder`` (which also
            provides ``get_dense_pe()``) and a callable ``mask_decoder``.
        img_embed: Image embedding tensor handed to the mask decoder; the box
            tensor is created on the same device.
        box_1024: Box prompt(s), anything ``torch.as_tensor`` accepts. A 2-D
            input gets a singleton axis inserted in the middle.
            (Presumably coordinates in the 1024-pixel model space - confirm
            against the caller.)
        H: Height of the returned mask.
        W: Width of the returned mask.

    Returns:
        ``np.uint8`` array of zeros and ones: shape ``(H, W)`` when the
        decoder yields a single mask, ``(B, H, W)`` for a batch of B masks.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.ndim == 2:
        box_torch = box_torch[:, None, :]

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    # Logits -> probabilities, then upsample to the requested output size.
    low_res_pred = torch.sigmoid(low_res_logits)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )

    # Remove only the channel axis and a singleton batch axis. A bare
    # squeeze() would also collapse H or W when either equals 1 and hand
    # back a mask with the wrong rank.
    low_res_pred = low_res_pred.squeeze(1)
    if low_res_pred.shape[0] == 1:
        low_res_pred = low_res_pred.squeeze(0)

    low_res_pred = low_res_pred.cpu().numpy()
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
|
|
|
def predict(img) -> Tuple["plt.Figure", float, str]:
    """Segment ``img`` with a fixed box prompt and return plot, timing and embedding.

    The image is converted to a 3-channel array, resized to 1024x1024,
    min-max scaled to [0, 1], encoded by the model's image encoder, and then
    decoded into a mask for a hard-coded bounding box.

    Args:
        img: Input image; anything ``np.array`` turns into an ``(H, W)``,
            ``(H, W, 3)`` or ``(H, W, 4)`` array. The Gradio interface in
            this file supplies a PIL image, whose channels are already RGB.

    Returns:
        A tuple ``(fig, pred_time, serialized_data)``:
        the Matplotlib figure with the box and the mask overlay, the elapsed
        seconds (rounded to 5 decimals) up to the end of mask decoding, and a
        JSON string holding a one-element list with the base64-encoded raw
        bytes of the image embedding.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("predict() requires an image, got None.")

    start_time = timer()

    img_np = np.array(img)

    # The input is already RGB, so no BGR->RGB conversion is applied: the
    # former cv2.cvtColor call reversed the channels instead of fixing them.
    if img_np.ndim == 2:
        # Grayscale: replicate the single channel three times.
        img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
    elif img_np.ndim == 3 and img_np.shape[-1] == 4:
        # RGBA: the encoder input is built from three channels, drop alpha.
        img_3c = img_np[:, :, :3]
    else:
        img_3c = img_np
    H, W, _ = img_3c.shape

    # Resize to the 1024x1024 input used below, then min-max scale to [0, 1].
    img_1024 = transform.resize(
        img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
    )

    # (H, W, C) -> (1, C, H, W) float tensor on the model's device.
    img_1024_tensor = (
        torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )

    medsam_model.eval()
    with torch.inference_mode():
        image_embedding = medsam_model.image_encoder(img_1024_tensor)

    # Fixed demo prompt in original-image pixel coordinates (x0, y0, x1, y1).
    input_box = np.array([[125, 275, 190, 350]])

    # Rescale the box from original-image coordinates to the 1024 grid.
    box_1024 = input_box / np.array([W, H, W, H]) * 1024

    medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
    pred_time = round(timer() - start_time, 5)

    # NOTE(review): figures created through pyplot stay registered until they
    # are closed; in a long-running server consider releasing them.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_3c)
    show_box(input_box[0], ax[0])
    ax[0].set_title("Input Image and Bounding Box")
    ax[1].imshow(img_3c)
    show_mask(medsam_seg, ax[1])
    show_box(input_box[0], ax[1])
    ax[1].set_title("MedSAM Segmentation")

    # Raw embedding bytes, base64-encoded; shape and dtype are not included.
    embedding_bytes = image_embedding.cpu().numpy().tobytes()
    serialized_data = json.dumps([base64.b64encode(embedding_bytes).decode('ascii')])

    return fig, pred_time, serialized_data
|
|
|
|
|
|
|
# Static text shown on the Gradio page.
title = "MedSam"
description = "a specialized SAM model finely tuned for the segmentation of medical images. With this app, effortlessly extract image embeddings using the model's advanced mask decoder."
article = "Created at gradio-sam-predictor-image-embedding-generator.ipynb ."

# One single-element row per entry of the local "examples" directory
# (resolved relative to the current working directory).
example_list = [[f"examples/{file_name}"] for file_name in os.listdir("examples")]

# One PIL image in; three outputs matching predict()'s return tuple.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Plot(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
        gr.JSON(label="Embedding Image"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the app at import time and request a public share link.
demo.launch(share=True, debug=False)
|
|