# Author: dennistrujillo
# Commit 80d0374 (verified): Updated to allow for nrrd uploads
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
import nrrd
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter
def load_nrrd(file_path):
    """Load an NRRD file and return a displayable 8-bit slice.

    For 3D data, the middle slice along the last axis is used; otherwise the
    data is used as-is. Intensities are min-max scaled to 0-255, and a 2D
    (grayscale) slice is replicated into three channels.

    Args:
        file_path: Path to the ``.nrrd`` file, passed to ``nrrd.read``.

    Returns:
        tuple: ``(img, H, W)`` where ``img`` is a ``uint8`` array (shape
        ``(H, W, 3)`` when the slice was 2D) and ``H``/``W`` are its first two
        dimensions.
    """
    data, _header = nrrd.read(file_path)

    # For a volume, show the middle slice along axis 2.
    if data.ndim == 3:
        img = data[:, :, data.shape[2] // 2]
    else:
        img = data

    # Scale in float64: subtracting the minimum in the source dtype (e.g.
    # int16) can wrap around.
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    value_range = hi - lo
    if value_range > 0:
        img = (img - lo) / value_range * 255.0
    else:
        # A constant slice would divide by zero and yield NaN; render it black.
        img = np.zeros_like(img)
    img = img.astype(np.uint8)

    # Replicate a grayscale (H, W) slice into (H, W, 3).
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)

    H, W = img.shape[:2]
    return img, H, W
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Segment the region(s) inside box prompt(s) with a SAM-style model.

    Args:
        medsam_model: Model exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``) and ``mask_decoder``.
        img_embed: Image embedding tensor from ``medsam_model.image_encoder``.
            The box tensor is placed on its device.
        box_1024: Box corners ``[x_min, y_min, x_max, y_max]`` in the
            1024x1024 model-input frame. Accepts a single box (``(4,)``,
            ``(1, 4)``, ``(1, 1, 4)``) or several boxes (``(B, 4)``,
            ``(B, 1, 4)``).
        H: Height of the returned mask.
        W: Width of the returned mask.

    Returns:
        np.ndarray: ``uint8`` mask of 0/1 values, with shape ``(H, W)`` for a
        single box or ``(B, H, W)`` for ``B`` boxes.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    # Normalize every accepted layout to (B, 4). The previous hard-coded
    # reshape(1, 4) only worked for exactly one box.
    box_torch = box_torch.reshape(-1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    # Per-box probabilities at the decoder's resolution: (B, 1, h, w).
    low_res_pred = torch.sigmoid(low_res_logits)
    # Resize the whole map to (H, W). This assumes the model input was the full
    # image stretched (not padded) to 1024x1024.
    pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )
    # Select channel 0 explicitly rather than squeeze(), so H == 1 or W == 1
    # is not collapsed.
    masks = pred[:, 0].cpu().numpy()  # (B, H, W)
    medsam_seg = (masks > 0.5).astype(np.uint8)
    if medsam_seg.shape[0] == 1:
        return medsam_seg[0]
    return medsam_seg
def visualize(image, mask, box):
    """Render the image with its prompt box beside the mask overlay.

    Args:
        image: Image array to display (RGB, or 2D shown in grayscale).
        mask: 2D array that is non-zero where the object was segmented.
        box: ``[x_min, y_min, x_max, y_max]`` in image pixel coordinates.

    Returns:
        PIL.Image.Image: PNG rendering of the two-panel figure.
    """
    # Use the object-oriented Figure API rather than pyplot. Pyplot keeps
    # global state that is not thread-safe, and Gradio runs handlers in worker
    # threads. A standalone Figure also never needs plt.close() and cannot leak
    # if rendering raises.
    from matplotlib.figure import Figure

    x_min, y_min, x_max, y_max = box[:4]

    fig = Figure(figsize=(10, 5))
    ax_box, ax_mask = fig.subplots(1, 2)

    ax_box.imshow(image, cmap="gray")
    ax_box.add_patch(
        plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            edgecolor="red",
            facecolor="none",
        )
    )

    ax_mask.imshow(image, cmap="gray")
    # Hide background pixels so only the segmented region is tinted. Overlaying
    # the raw 0/1 mask would shade the whole image with the colormap's 0 color.
    overlay = np.ma.masked_where(np.asarray(mask) == 0, mask)
    ax_mask.imshow(overlay, alpha=0.5, cmap="jet", vmin=0, vmax=1)

    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
MEDSAM_CHECKPOINT_PATH = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint

# Models already built, keyed by (checkpoint path, device), so requests reuse
# them instead of reloading the checkpoint every time.
_MEDSAM_MODELS = {}


def _get_medsam_model(device, checkpoint_path=MEDSAM_CHECKPOINT_PATH):
    """Return the ViT-B MedSAM model for *device*, building it on first use only."""
    key = (checkpoint_path, device)
    model = _MEDSAM_MODELS.get(key)
    if model is None:
        model = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
        model = model.to(device)
        model.eval()
        _MEDSAM_MODELS[key] = model
    return model


def _parse_box(points, H, W):
    """Return an ordered, clipped ``(x_min, y_min, x_max, y_max)`` from a box prompt.

    Accepted layouts:
    - a flat ``[x1, y1, _, x2, y2, _]`` list (entries 2 and 5 are ignored);
    - a plain ``[x1, y1, x2, y2]`` list;
    - a list of such prompts, in which case the first one is used.

    Corners may be given in either order. Coordinates are clipped to the
    ``W x H`` image.

    Raises:
        ValueError: If *points* does not hold enough coordinates.
    """
    if points and isinstance(points[0], (list, tuple)):
        points = points[0]
    if not points:
        raise ValueError("Insufficient data for bounding box coordinates.")
    if len(points) >= 6:
        corners = (points[0], points[1], points[3], points[4])
    elif len(points) == 4:
        corners = tuple(points)
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")

    def _clamp(value, upper):
        return min(max(float(value), 0.0), float(upper))

    x1, y1, x2, y2 = corners
    x_min, x_max = sorted((_clamp(x1, W), _clamp(x2, W)))
    y_min, y_max = sorted((_clamp(y1, H), _clamp(y2, H)))
    return x_min, y_min, x_max, y_max


def process_nrrd(nrrd_file, points):
    """Segment the boxed region of an NRRD file's displayed slice with MedSAM.

    Args:
        nrrd_file: Upload from ``gr.File``. This is either an object with a
            ``.name`` path or a plain path string, depending on the Gradio
            version.
        points: Bounding-box prompt in pixel coordinates of the slice returned
            by ``load_nrrd`` (x = column, y = row).

    Returns:
        PIL.Image.Image: Side-by-side rendering of the box and the predicted mask.

    Raises:
        ValueError: If no file was uploaded or the box prompt is incomplete.
    """
    if nrrd_file is None:
        raise ValueError("Please upload an NRRD file.")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load and preprocess the NRRD file.
    file_path = getattr(nrrd_file, "name", nrrd_file)
    image, H, W = load_nrrd(file_path)
    x_min, y_min, x_max, y_max = _parse_box(points, H, W)

    # Resize to the 1024x1024 model input and rescale intensities to [0, 1].
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    # (1024, 1024, 3) -> (1, 3, 1024, 1024)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    medsam_model = _get_medsam_model(device)
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)

    # Map the box from original pixel coordinates into the 1024x1024 frame.
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors

    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
    return visualize(image, mask, [x_min, y_min, x_max, y_max])
def _process_with_box_text(nrrd_file, box_text):
    """Decode the bounding-box text entered in the UI and run ``process_nrrd``.

    ``gr.JSON`` is a display-only component, so the box is collected as text
    and parsed here. A value that is already a list or tuple is passed through
    unchanged, and blank input becomes an empty list.

    Raises:
        ValueError: If the text is not valid JSON.
    """
    import json

    if isinstance(box_text, (list, tuple)):
        points = list(box_text)
    elif box_text is None or not str(box_text).strip():
        points = []
    else:
        try:
            points = json.loads(box_text)
        except json.JSONDecodeError as err:
            raise ValueError(f"Bounding box coordinates must be valid JSON: {err}") from err
    return process_nrrd(nrrd_file, points)


# Set up Gradio interface
iface = gr.Interface(
    fn=_process_with_box_text,
    inputs=[
        gr.File(label="NRRD File"),
        gr.Textbox(
            label="Bounding Box Coordinates",
            placeholder="[x_min, y_min, 2, x_max, y_max, 3]",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),
    ],
    title="ROI Selection with MEDSAM for NRRD Files",
    description=(
        "Upload an NRRD file and enter the bounding box as a JSON list "
        "[x_min, y_min, 2, x_max, y_max, 3] in pixel coordinates of the image "
        "(the middle slice for 3D volumes); the 3rd and 6th values are ignored."
    ),
)
# Launch the interface
iface.launch()