dennistrujillo committed: Updated to allow for nrrd uploads
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
 import numpy as np
 import pydicom
 import os
+import nrrd
 from skimage import transform
 import torch
 from segment_anything import sam_model_registry
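
Note: the nrrd module imported above ships in the pynrrd package, so the PyPI name differs from the import name; the Space's requirements.txt needs pynrrd, not nrrd. A quick environment check, assuming nothing beyond pynrrd itself:

    # pip install pynrrd  -- installs as the "nrrd" module
    import nrrd
    print(nrrd.__version__)  # should print the installed pynrrd version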
@@ -12,12 +13,18 @@ import torch.nn.functional as F
 import io
 from gradio_image_prompter import ImagePrompter
 
-def
-
-
-
+def load_nrrd(file_path):
+    data, header = nrrd.read(file_path)
+
+    # If the data is 3D, take the middle slice
+    if len(data.shape) == 3:
+        middle_slice = data.shape[2] // 2
+        img = data[:, :, middle_slice]
     else:
-        img =
+        img = data
+
+    # Normalize the image to 0-255 range
+    img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
 
     # Convert grayscale to 3-channel RGB by replicating channels
     if len(img.shape) == 2:  # Grayscale image (height, width)
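
Note: pynrrd's read/write pair makes it easy to sanity-check the new load_nrrd path with a synthetic volume. A minimal sketch, assuming only numpy and pynrrd are installed (the file name is arbitrary):

    import numpy as np
    import nrrd

    # Write a tiny synthetic 3-D volume; nrrd.write is pynrrd's counterpart to nrrd.read
    volume = np.random.randint(0, 1000, size=(64, 64, 10)).astype(np.int16)
    nrrd.write('test_volume.nrrd', volume)

    # nrrd.read returns (data, header); load_nrrd above keeps only the middle slice
    data, header = nrrd.read('test_volume.nrrd')
    middle = data[:, :, data.shape[2] // 2]  # mirrors the slice selection in load_nrrd
    assert middle.shape == (64, 64)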
@@ -45,7 +52,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
         sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
         dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
         multimask_output=False,
-
+    )
 
     low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
 
@@ -59,7 +66,6 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
     medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
     return medsam_seg
 
-# Function for visualizing images with masks
 def visualize(image, mask, box):
     fig, ax = plt.subplots(1, 2, figsize=(10, 5))
     ax[0].imshow(image, cmap='gray')
@@ -68,30 +74,24 @@ def visualize(image, mask, box):
     ax[1].imshow(mask, alpha=0.5, cmap="jet")
     plt.tight_layout()
 
-    # Convert matplotlib figure to a PIL Image
     buf = io.BytesIO()
     fig.savefig(buf, format='png')
-    plt.close(fig)
+    plt.close(fig)
     buf.seek(0)
     pil_img = Image.open(buf)
 
     return pil_img
 
-
-def process_images(img_dict):
+def process_nrrd(nrrd_file, points):
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    # Load and preprocess
-
-
+    # Load and preprocess NRRD file
+    image, H, W = load_nrrd(nrrd_file.name)
+
     if len(points) >= 6:
         x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
     else:
         raise ValueError("Insufficient data for bounding box coordinates.")
-    image, H, W = img, img.shape[0], img.shape[1] #
-    if len(image.shape) == 2:
-        image = np.repeat(image[:, :, None], 3, axis=-1)
-    H, W, _ = image.shape
 
     image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
     image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
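
Note: the flat six-element points list that process_nrrd expects appears to mirror the box encoding used by gradio_image_prompter, where indices 2 and 5 hold point-type labels rather than coordinates; that is why the unpacking skips points[2] and points[5]. A sketch of the JSON payload the new gr.JSON input would take (the coordinates are made up for illustration):

    # Hypothetical payload: [x_min, y_min, label, x_max, y_max, label]
    points = [50, 60, 2, 220, 240, 3]
    x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    assert x_min < x_max and y_min < y_max  # a well-formed box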
@@ -120,16 +120,17 @@ def process_images(img_dict):
 
 # Set up Gradio interface
 iface = gr.Interface(
-    fn=
+    fn=process_nrrd,
     inputs=[
-
+        gr.File(label="NRRD File"),
+        gr.JSON(label="Bounding Box Coordinates")
     ],
     outputs=[
         gr.Image(type="pil", label="Processed Image")
     ],
-    title="ROI Selection with MEDSAM",
-    description="Upload an
+    title="ROI Selection with MEDSAM for NRRD Files",
+    description="Upload an NRRD file and provide bounding box coordinates for processing."
 )
 
 # Launch the interface
-iface.launch()
+iface.launch()