import gradio as gr
import os
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from typing import Tuple, Dict
from timeit import default_timer as timer
from skimage import io, transform
import base64
import json
import torch.nn.functional as F
from model import create_sam_model
# 1. Setup variables: which SAM backbone to build, where its weights live,
# and which device to run on (GPU when available).
model_type = "vit_b"
checkpoint = "sam_vit_b_01ec64.pth"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# 2. Build the model and load the saved weights onto `device`.
medsam_model = create_sam_model(model_type, checkpoint, device)
# 3.Predict fn
def show_mask(mask, ax, color=None):
    """Overlay a binary mask on matplotlib axes as a translucent colour layer.

    Args:
        mask: Array whose last two dimensions are (H, W); any leading
            dimensions must have size 1. Nonzero pixels are painted.
        ax: Matplotlib axes (anything exposing ``imshow``).
        color: Optional RGBA sequence of four floats in [0, 1]. Defaults to
            translucent blue (30, 144, 255) with alpha 0.6.
    """
    if color is None:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        color = np.asarray(color, dtype=float)
    h, w = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour to get an
    # (H, W, 4) RGBA image that is fully transparent where the mask is zero.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # The original computed mask_image but never drew it.
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``.

    Points labelled 1 are drawn as green stars, points labelled 0 as red
    stars, both with a white edge. ``coords`` is indexed as ``[:, 0]`` (x)
    and ``[:, 1]`` (y) after selecting rows by ``labels``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, point_color in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=point_color, **star_style)
def show_box(box, ax):
    """Draw ``box`` = [x_min, y_min, x_max, y_max] on ``ax`` as a green outline.

    The rectangle has a transparent fill and a line width of 2.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Predict a binary mask for box prompt(s) from a precomputed image embedding.

    Args:
        medsam_model: SAM-style model exposing ``prompt_encoder`` (callable
            with ``points``/``boxes``/``masks`` and providing
            ``get_dense_pe()``) and ``mask_decoder``.
        img_embed: Output of ``medsam_model.image_encoder``; the shape
            comments below assume (B, 256, 64, 64) — confirm for other
            backbones.
        box_1024: Box(es) [x0, y0, x1, y1] in 1024x1024 model-input
            coordinates. A 2-D (B, 4) array gets a prompt axis inserted to
            become (B, 1, 4).
        H, W: Size of the original image; the mask is upsampled to it.

    Returns:
        ``np.uint8`` array, squeezed (``(H, W)`` for a single box), with 1
        where the predicted probability exceeds 0.5 and 0 elsewhere.
    """
    # The embedding is typically produced inside torch.inference_mode();
    # inference tensors cannot be used in autograd-recording ops outside it,
    # so the decoder must run in inference mode as well.
    with torch.inference_mode():
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
        if box_torch.dim() == 2:
            box_torch = box_torch[:, None, :]  # (B, 1, 4)

        sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, _ = medsam_model.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,  # one mask per box
        )

        low_res_pred = torch.sigmoid(low_res_logits)  # (B, 1, 256, 256)
        # Upsample probabilities (not logits) to the original resolution,
        # then threshold.
        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )  # (B, 1, H, W)

    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W) for B == 1
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
def _to_rgb(img) -> np.ndarray:
    """Return ``img`` as an (H, W, 3) array, keeping the channel order as-is.

    ``gr.Image(type="pil")`` already delivers RGB data, so no BGR->RGB swap
    is applied (the previous ``cv2.COLOR_BGR2RGB`` call swapped R and B).
    (H, W) and (H, W, 1) inputs are replicated to three channels; a fourth
    (alpha) channel is dropped.
    """
    img_np = np.asarray(img)
    if img_np.ndim == 2:
        img_np = img_np[:, :, None]
    if img_np.shape[-1] == 1:
        return np.repeat(img_np, 3, axis=-1)
    return img_np[:, :, :3]


def _preprocess_1024(img_3c: np.ndarray) -> torch.Tensor:
    """Resize to 1024x1024, min-max scale to [0, 1] and return a (1, 3, 1024, 1024) float tensor on ``device``."""
    img_1024 = transform.resize(
        img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    )
    # Min-max normalize; the clip avoids division by zero for constant images.
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
    )
    # (H, W, 3) -> (1, 3, H, W)
    return torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)


def _scale_box(box, H, W):
    """Clip a pixel box [x0, y0, x1, y1] to the image and rescale it to 1024x1024.

    Returns:
        ``(input_box, box_1024)``: the clipped box in original pixel
        coordinates and the same box in model-input coordinates, both (1, 4).
    """
    input_box = np.array(box, dtype=float).reshape(1, 4)  # copy: never mutate caller data
    input_box[:, 0::2] = np.clip(input_box[:, 0::2], 0, W)
    input_box[:, 1::2] = np.clip(input_box[:, 1::2], 0, H)
    box_1024 = input_box / np.array([W, H, W, H]) * 1024
    return input_box, box_1024


def _plot_result(img_3c, input_box, medsam_seg):
    """Return a figure: input with box (left), mask overlay with box (right)."""
    # Create the Figure directly instead of via pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), which leaks memory in a
    # long-running server that builds one figure per request.
    fig = plt.Figure(figsize=(10, 5))
    ax = fig.subplots(1, 2)
    ax[0].imshow(img_3c)
    show_box(input_box[0], ax[0])
    ax[0].set_title("Input Image and Bounding Box")
    ax[1].imshow(img_3c)
    show_mask(medsam_seg, ax[1])
    show_box(input_box[0], ax[1])
    ax[1].set_title("MedSAM Segmentation")
    return fig


def predict(img, box=(125, 275, 190, 350)) -> Tuple[plt.Figure, float, str]:
    """Segment the region inside ``box`` and return the plot, timing and embedding.

    Args:
        img: Input image. Gradio passes a PIL image; any array-like accepted
            by ``np.asarray`` also works.
        box: Prompt box [x0, y0, x1, y1] in original pixel coordinates.
            Defaults to the box that used to be hard-coded; it is clipped to
            the image bounds.

    Returns:
        ``(fig, pred_time, serialized_data)``:
        - ``fig``: matplotlib Figure with the input and segmentation panels.
        - ``pred_time``: seconds spent on preprocessing and inference,
          rounded to 5 decimals.
        - ``serialized_data``: JSON string holding a one-element list with
          the base64-encoded raw bytes of the image embedding.
    """
    start_time = timer()

    img_3c = _to_rgb(img)
    H, W, _ = img_3c.shape
    img_1024_tensor = _preprocess_1024(img_3c)
    input_box, box_1024 = _scale_box(box, H, W)

    # No gradients are needed. Both the encoder and the decoder run inside
    # inference mode so the inference tensors stay valid for the decoder.
    # NOTE(review): eval() mode is presumably set inside create_sam_model —
    # confirm.
    with torch.inference_mode():
        image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)
        medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)

    # Prediction time covers preprocessing + encoder + decoder, not plotting.
    pred_time = round(timer() - start_time, 5)

    fig = _plot_result(img_3c, input_box, medsam_seg)

    # Raw embedding bytes -> base64 text so the value is JSON-serializable.
    # dtype/shape are not included; consumers must know them.
    embedding_bytes = image_embedding.cpu().numpy().tobytes()
    serialized_data = json.dumps([base64.b64encode(embedding_bytes).decode('ascii')])

    return fig, pred_time, serialized_data
# 4. Gradio app
# Title, description and article strings shown in the UI.
title = "MedSam"
description = "A specialized SAM model fine-tuned for the segmentation of medical images. With this app, effortlessly extract image embeddings using the model's image encoder."
article = "Created in gradio-sam-predictor-image-embedding-generator.ipynb."

# Build the examples list from the "examples/" directory. It is sorted so the
# gallery order is stable, hidden files (e.g. .gitkeep) are skipped, and it is
# empty rather than crashing if the directory is absent.
examples_dir = "examples"
example_list = (
    [
        [os.path.join(examples_dir, name)]
        for name in sorted(os.listdir(examples_dir))
        if not name.startswith(".")
    ]
    if os.path.isdir(examples_dir)
    else []
)

# Create the Gradio demo.
demo = gr.Interface(
    fn=predict,  # maps the uploaded image to the outputs below
    inputs=gr.Image(type="pil"),
    # predict returns three values, so there are three output components.
    outputs=[
        gr.Plot(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
        gr.JSON(label="Embedding Image"),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    # debug=False: don't block/print errors locally; share=True: also request
    # a publicly shareable URL.
    demo.launch(debug=False, share=True)