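"""Gradio demo for interactive ROI segmentation of NRRD images with MedSAM.

The app loads an NRRD file (taking the middle slice of a 3D volume), encodes it
with the MedSAM image encoder, and decodes a binary segmentation mask from a
user-supplied bounding-box prompt, returning a side-by-side visualization.
"""
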
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
import nrrd
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter

def load_nrrd(file_path):
    data, header = nrrd.read(file_path)
    # If the data is 3D, take the middle slice
    if len(data.shape) == 3:
        middle_slice = data.shape[2] // 2
        img = data[:, :, middle_slice]
    else:
        img = data
    # Normalize the image to the 0-255 range
    img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
    # Convert grayscale to 3-channel RGB by replicating the channel
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)  # (height, width, 3)
    H, W = img.shape[:2]
    return img, H, W

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)
    box_torch = box_torch.reshape(1, 4)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    # Upsample the low-resolution prediction back to the original image size
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg

def visualize(image, mask, box):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Left panel: input image with the bounding-box prompt drawn on top
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                  edgecolor="red", facecolor="none"))
    # Right panel: segmentation mask overlaid on the image
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Render the figure to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img

def process_nrrd(nrrd_file, points):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load and preprocess the NRRD file
    image, H, W = load_nrrd(nrrd_file.name)
    # Expect a flat list of at least six values, where indices 0/1 give
    # x_min/y_min and indices 3/4 give x_max/y_max
    if len(points) >= 6:
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")
    # Resize to the 1024x1024 input size expected by MedSAM and rescale to [0, 1]
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
    # Initialize the MedSAM model on the selected device
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    # Generate the image embedding
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)
    # Scale the box coordinates from the original image size to the 1024x1024 input space
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
    # Visualization
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization

# Set up the Gradio interface
iface = gr.Interface(
    fn=process_nrrd,
    inputs=[
        gr.File(label="NRRD File"),
        gr.JSON(label="Bounding Box Coordinates")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MedSAM for NRRD Files",
    description="Upload an NRRD file and provide bounding box coordinates for processing."
)
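# Example "Bounding Box Coordinates" input (hypothetical values): process_nrrd reads
# indices 0, 1, 3, 4 from a list of at least six numbers, e.g. [80, 120, 2, 260, 310, 3].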
# Launch the interface
iface.launch()