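"""MedSAM ROI-selection demo (MedSAMTest / app.py).

Gradio app that takes an uploaded medical image, lets the user draw a
bounding box with ImagePrompter, and runs box-prompted MedSAM segmentation
on the selected region.
"""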
import io

import gradio as gr
import matplotlib.pyplot as plt
import nrrd  # NRRD medical image format support
import numpy as np
import pydicom
import torch
import torch.nn.functional as F
from gradio_image_prompter import ImagePrompter
from PIL import Image
from segment_anything import sam_model_registry
from skimage import transform
def load_image(file_path):
    """Read a DICOM, NRRD, or standard image file and return (img, H, W)."""
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    # Convert grayscale (H, W) to 3-channel (H, W, 3) by replicating the channel.
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    H, W = img.shape[:2]
    return img, H, W
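# Example usage (hypothetical path). Note that load_image is a standalone
# helper; the Gradio flow below receives image arrays directly from
# ImagePrompter rather than going through this function:
#   img, H, W = load_image("scan.dcm")  # img: (H, W, 3) array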
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run MedSAM box-prompted decoding; returns a binary (H, W) uint8 mask."""
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    box_torch = box_torch.reshape(1, 4)  # single box prompt: (1, 4)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
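# Minimal usage sketch (hypothetical box values; assumes `medsam_model` and a
# precomputed image embedding, as produced in process_images below):
#   box_1024 = np.array([100, 150, 400, 500])  # (x_min, y_min, x_max, y_max) in 1024x1024 space
#   mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)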
# Visualize the image with the prompt box and the predicted mask overlay.
def visualize(image, mask, box):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap="gray")
    ax[0].add_patch(
        plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            edgecolor="red",
            facecolor="none",
        )
    )
    ax[1].imshow(image, cmap="gray")
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Convert the matplotlib figure to a PIL Image.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # Release figure memory.
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
# Main function for the Gradio app.
def process_images(img_dict):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ImagePrompter returns a dict with the image array and the drawn prompts.
    # A box prompt is encoded as [x1, y1, 2, x2, y2, 3], so the corner
    # coordinates sit at indices 0, 1, 3, and 4.
    img = img_dict["image"]
    points = img_dict["points"][0]  # First (and possibly only) prompt.
    if len(points) >= 6:
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")

    image = img
    if image.ndim == 2:  # Grayscale: replicate to 3 channels.
        image = np.repeat(image[:, :, None], 3, axis=-1)
    H, W, _ = image.shape

    # Resize to the 1024x1024 input expected by the SAM image encoder,
    # then min-max normalize to [0, 1].
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model (note: reloaded on every call).
    model_checkpoint_path = "medsam_vit_b.pth"  # Path to the MedSAM ViT-B checkpoint.
    medsam_model = sam_model_registry["vit_b"](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Generate the image embedding.
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)

    # Rescale the box from original image coordinates to 1024x1024 space.
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors

    # Perform inference.
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)

    # Visualization.
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization
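# Possible optimization (a sketch, not part of the original flow): load the
# checkpoint once at module level and reuse it across requests, e.g.
#   MEDSAM_MODEL = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth").eval()
# then reference MEDSAM_MODEL inside process_images instead of reloading per call.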
# Set up the Gradio interface.
iface = gr.Interface(
    fn=process_images,
    inputs=[ImagePrompter(label="Image")],
    outputs=[gr.Image(type="pil", label="Processed Image")],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and draw a bounding box around the region of interest.",
)

# Launch the interface.
iface.launch()