# dennistrujillo's picture
# Updated load_image to convert image to 3 channels if needed
# 9ba0bac verified
# raw
# history blame
# 3.41 kB
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import io
def load_image(file_path):
    """Load a DICOM or regular image file as a 3-channel array.

    Args:
        file_path: Path (``str`` or ``os.PathLike``) of the file to read.
            Names ending in ``.dcm`` (case-insensitive) are read with
            ``pydicom``; everything else is opened with Pillow.

    Returns:
        tuple: ``(img, H, W)`` where ``img`` is a NumPy array and ``H`` / ``W``
        are its first two dimensions. Single-channel and Pillow
        palette/alpha/CMYK inputs are expanded or converted to 3 channels.
    """
    path = os.fspath(file_path)

    # Compare case-insensitively so "SCAN.DCM" is still routed to pydicom.
    if path.lower().endswith(".dcm"):
        img = pydicom.dcmread(path).pixel_array
        # NOTE(review): multi-frame DICOMs presumably yield (frames, H, W)
        # here, which this function does not handle -- confirm against inputs.
    else:
        # The context manager closes the underlying file once pixels are read.
        with Image.open(path) as pil_img:
            # Palette indices, alpha and CMYK layouts are not RGB; let Pillow
            # convert them instead of misreading the raw array.
            if pil_img.mode in ("P", "PA", "LA", "RGBA", "CMYK"):
                pil_img = pil_img.convert("RGB")
            img = np.array(pil_img)

    # Convert grayscale to 3-channel RGB by replicating channels
    if img.ndim == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:  # (height, width, 1)
        img = np.repeat(img, 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:  # e.g. RGBA: drop the 4th channel
        img = img[..., :3]

    H, W = img.shape[:2]
    return img, H, W
# MedSAM inference function
def medsam_inference(medsam_model, img, box, H, W, target_size, device):
    """Segment the region inside ``box`` with a MedSAM/SAM model.

    Args:
        medsam_model: Model exposing ``image_encoder``, ``prompt_encoder`` and
            ``mask_decoder``. It is expected to already live on ``device``.
        img: Image array of shape ``(H, W)`` or ``(H, W, 3)``.
        box: Box prompt ``(x0, y0, x1, y1)`` in original-image pixel
            coordinates.
        H: Height of the original image in pixels.
        W: Width of the original image in pixels.
        target_size: Side length of the square image fed to the encoder.
        device: Torch device (or device string) the input tensors are moved to.

    Returns:
        numpy.ndarray: Binary ``uint8`` mask of shape ``(H, W)``.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        # The image encoder takes 3-channel input; replicate grayscale.
        img = np.stack((img,) * 3, axis=-1)

    # Resize image and box to target size
    img_resized = transform.resize(img, (target_size, target_size), anti_aliasing=True)
    box_resized = np.array(box, dtype=float) * (target_size / np.array([W, H, W, H]))

    # Min-max scale to [0, 1]; the epsilon guards against a constant image.
    lo, hi = img_resized.min(), img_resized.max()
    img_norm = (img_resized - lo) / max(hi - lo, 1e-8)

    # (S, S, 3) -> (1, 3, S, S): channels first, then a batch dimension.
    img_tensor = (
        torch.from_numpy(img_norm).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )
    # Model expects box in format (x0, y0, x1, y1); shape (batch, 1, 4)
    box_tensor = torch.tensor(box_resized, dtype=torch.float32, device=device).reshape(1, 1, 4)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(img_tensor)
        sparse_emb, dense_emb = medsam_model.prompt_encoder(
            points=None,
            boxes=box_tensor,
            masks=None,
        )
        low_res_logits, _ = medsam_model.mask_decoder(
            image_embeddings=img_embed,
            image_pe=medsam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=False,
        )
        # Post-process mask: probabilities, resized back to original size
        prob = torch.sigmoid(low_res_logits)
        prob = torch.nn.functional.interpolate(
            prob,
            size=(int(H), int(W)),
            mode="bilinear",
            align_corners=False,
        )

    prob_np = prob.squeeze(0).squeeze(0).cpu().numpy()
    return (prob_np > 0.5).astype(np.uint8)
# Function for visualizing images with masks
def visualize(image, mask, box):
    """Draw the box prompt and the mask overlay side by side as a PNG.

    Args:
        image: Image array to display in both panels.
        mask: Array overlaid semi-transparently on the right-hand panel.
        box: Indexable ``(x0, y0, x1, y1)`` outlined in red on the left panel.

    Returns:
        io.BytesIO: In-memory PNG, rewound to position 0.
    """
    figure, (prompt_ax, overlay_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: the image with the box prompt outlined.
    prompt_ax.imshow(image, cmap='gray')
    corner = (box[0], box[1])
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(corner, width, height, edgecolor="red", facecolor="none")
    prompt_ax.add_patch(outline)

    # Right panel: the image with the mask blended on top.
    overlay_ax.imshow(image, cmap='gray')
    overlay_ax.imshow(mask, alpha=0.5, cmap="jet")

    plt.tight_layout()

    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format='png')
    plt.close(figure)  # release the figure so repeated calls do not accumulate
    png_buffer.seek(0)
    return png_buffer
# Main function for Gradio app
# Define the checkpoint path
MODEL_CHECKPOINT_PATH = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint

# Loaded models keyed by device string, so the checkpoint is read once per
# process instead of once per request.
_MODEL_CACHE = {}


def _get_medsam_model(device):
    """Return the MedSAM model on ``device``, loading it on first use."""
    if device not in _MODEL_CACHE:
        model = sam_model_registry['vit_b'](checkpoint=MODEL_CHECKPOINT_PATH)
        # Always move the model: inputs are placed on `device`, so the weights
        # must be there too (on CUDA as well as on CPU).
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[device] = model
    return _MODEL_CACHE[device]


def process_images(file, x_min, y_min, x_max, y_max):
    """Gradio callback: segment the boxed region of an uploaded image.

    Args:
        file: Uploaded file, either a path or an object whose ``name``
            attribute holds the path.
        x_min: Left edge of the box prompt, in image pixels.
        y_min: Top edge of the box prompt, in image pixels.
        x_max: Right edge of the box prompt, in image pixels.
        y_max: Bottom edge of the box prompt, in image pixels.

    Returns:
        PIL.Image.Image: Rendered figure showing the box prompt and the mask
        overlay.

    Raises:
        ValueError: If no file was uploaded or a coordinate is missing.
    """
    if file is None:
        raise ValueError("Please upload an image file.")
    if any(value is None for value in (x_min, y_min, x_max, y_max)):
        raise ValueError("Please provide all four box coordinates.")

    # NOTE(review): depending on the Gradio version the File component
    # presumably hands over a temp-file wrapper rather than a path -- accept both.
    file_path = getattr(file, "name", file)
    image, H, W = load_image(file_path)

    # Check if CUDA is available, and set the device accordingly
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    medsam_model = _get_medsam_model(device)

    # Tolerate corners entered in either order.
    x_min, x_max = sorted((x_min, x_max))
    y_min, y_max = sorted((y_min, y_max))
    box = [x_min, y_min, x_max, y_max]

    # The encoder works at a fixed input resolution; use the model's own value
    # rather than the height of the uploaded image.
    target_size = medsam_model.image_encoder.img_size
    mask = medsam_inference(medsam_model, image, box, H, W, target_size, device)

    visualization = visualize(image, mask, box)
    # Decode the PNG buffer so the image output component receives a PIL image.
    return Image.open(visualization)


# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
        gr.Number(label="X min"),
        gr.Number(label="Y min"),
        gr.Number(label="X max"),
        gr.Number(label="Y max"),
    ],
    # The callback returns a rendered image, not a figure object.
    outputs=gr.Image(label="MedSAM segmentation"),
)
iface.launch()