import io

import gradio as gr
import matplotlib.pyplot as plt
import nrrd
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from segment_anything import sam_model_registry
from skimage import transform


def load_nrrd(file_path):
    """Read an NRRD file and return a normalized 3-channel uint8 image."""
    data, header = nrrd.read(file_path)
    # If the data is 3D, take the middle slice along the last axis.
    if len(data.shape) == 3:
        middle_slice = data.shape[2] // 2
        img = data[:, :, middle_slice]
    else:
        img = data
    # Normalize the image to the 0-255 range.
    img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
    # Convert grayscale (H, W) to 3-channel RGB (H, W, 3) by replicating the channel.
    if len(img.shape) == 2:
        img = np.stack((img,) * 3, axis=-1)
    H, W = img.shape[:2]
    return img, H, W


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run the MedSAM prompt encoder and mask decoder for a single box prompt."""
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    # Upsample the low-resolution prediction back to the original image size.
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg


def visualize(image, mask, box):
    """Plot the image with the box prompt next to the image with the mask overlay."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap="gray")
    ax[0].add_patch(
        plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            edgecolor="red",
            facecolor="none",
        )
    )
    ax[1].imshow(image, cmap="gray")
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img


def process_nrrd(nrrd_file, points):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load and preprocess the NRRD file. Depending on the Gradio version,
    # gr.File yields either a tempfile wrapper (with .name) or a plain path.
    file_path = nrrd_file.name if hasattr(nrrd_file, "name") else nrrd_file
    image, H, W = load_nrrd(file_path)
    # The JSON prompt is expected in the 6-element box format used by
    # gradio_image_prompter: [x1, y1, label, x2, y2, label].
    if len(points) >= 6:
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")
    # Resize to the 1024x1024 input MedSAM expects, then rescale to [0, 1].
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    image_tensor = (
        torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )
    # Initialize the MedSAM model.
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint.
    medsam_model = sam_model_registry["vit_b"](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    # Generate the image embedding.
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)
    # Scale the box coordinates from the original image size to the 1024 grid.
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([[x_min, y_min, x_max, y_max]]) * scale_factors  # (1, 4)
    # Perform inference.
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
    # Visualization.
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization
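
# Optional standalone check (an illustrative sketch; _smoke_test and its
# synthetic input are additions, not part of the app). It exercises
# medsam_inference end to end on a random image so the tensor shapes and the
# checkpoint path can be verified without the web UI. Assumes the
# "medsam_vit_b.pth" checkpoint is available in the working directory.
def _smoke_test(checkpoint_path="medsam_vit_b.pth"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = sam_model_registry["vit_b"](checkpoint=checkpoint_path)
    model = model.to(device).eval()
    # Synthetic 1024x1024 RGB image already scaled to [0, 1].
    img = np.random.rand(1024, 1024, 3).astype(np.float32)
    tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        embed = model.image_encoder(tensor)
    # A centered box prompt, already in 1024x1024 coordinates.
    box = np.array([[256.0, 256.0, 768.0, 768.0]])
    mask = medsam_inference(model, embed, box, 1024, 1024)
    print("mask shape:", mask.shape, "| foreground pixels:", int(mask.sum()))
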
# Set up the Gradio interface.
iface = gr.Interface(
    fn=process_nrrd,
    inputs=[
        gr.File(label="NRRD File"),
        gr.JSON(label="Bounding Box Coordinates"),
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),
    ],
    title="ROI Selection with MedSAM for NRRD Files",
    description=(
        "Upload an NRRD file and provide bounding box coordinates as a JSON "
        "list in the form [x1, y1, label, x2, y2, label]."
    ),
)

# Launch the interface.
iface.launch()
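
# Example prompt (an assumption about the label convention): with the app
# running at the printed local URL, upload a .nrrd file and paste a JSON list
# such as [50, 60, 2, 200, 220, 3], where (50, 60) is the top-left corner,
# (200, 220) the bottom-right corner, and 2/3 are the SAM-style box-corner
# labels that gradio_image_prompter emits.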