dennistrujillo commited on
Commit
80d0374
·
verified ·
1 Parent(s): e40e1c6

Updated to allow for nrrd uploads

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -3,6 +3,7 @@ import pandas as pd
3
  import numpy as np
4
  import pydicom
5
  import os
 
6
  from skimage import transform
7
  import torch
8
  from segment_anything import sam_model_registry
@@ -12,12 +13,18 @@ import torch.nn.functional as F
12
  import io
13
  from gradio_image_prompter import ImagePrompter
14
 
15
- def load_image(file_path):
16
- if file_path.endswith(".dcm"):
17
- ds = pydicom.dcmread(file_path)
18
- img = ds.pixel_array
 
 
 
19
  else:
20
- img = np.array(Image.open(file_path))
 
 
 
21
 
22
  # Convert grayscale to 3-channel RGB by replicating channels
23
  if len(img.shape) == 2: # Grayscale image (height, width)
@@ -45,7 +52,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
45
  sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
46
  dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
47
  multimask_output=False,
48
- )
49
 
50
  low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
51
 
@@ -59,7 +66,6 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
59
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
60
  return medsam_seg
61
 
62
- # Function for visualizing images with masks
63
  def visualize(image, mask, box):
64
  fig, ax = plt.subplots(1, 2, figsize=(10, 5))
65
  ax[0].imshow(image, cmap='gray')
@@ -68,30 +74,24 @@ def visualize(image, mask, box):
68
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
69
  plt.tight_layout()
70
 
71
- # Convert matplotlib figure to a PIL Image
72
  buf = io.BytesIO()
73
  fig.savefig(buf, format='png')
74
- plt.close(fig) # Close the figure to release memory
75
  buf.seek(0)
76
  pil_img = Image.open(buf)
77
 
78
  return pil_img
79
 
80
- # Main function for Gradio app
81
- def process_images(img_dict):
82
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
83
 
84
- # Load and preprocess image
85
- img = img_dict['image']
86
- points = img_dict['points'][0] # Accessing the first (and possibly only) set of points
87
  if len(points) >= 6:
88
  x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
89
  else:
90
  raise ValueError("Insufficient data for bounding box coordinates.")
91
- image, H, W = img, img.shape[0], img.shape[1] #
92
- if len(image.shape) == 2:
93
- image = np.repeat(image[:, :, None], 3, axis=-1)
94
- H, W, _ = image.shape
95
 
96
  image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
97
  image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
@@ -120,16 +120,17 @@ def process_images(img_dict):
120
 
121
  # Set up Gradio interface
122
  iface = gr.Interface(
123
- fn=process_images,
124
  inputs=[
125
- ImagePrompter(label="Image")
 
126
  ],
127
  outputs=[
128
  gr.Image(type="pil", label="Processed Image")
129
  ],
130
- title="ROI Selection with MEDSAM",
131
- description="Upload an image and select regions of interest for processing."
132
  )
133
 
134
  # Launch the interface
135
- iface.launch()
 
3
  import numpy as np
4
  import pydicom
5
  import os
6
+ import nrrd
7
  from skimage import transform
8
  import torch
9
  from segment_anything import sam_model_registry
 
13
  import io
14
  from gradio_image_prompter import ImagePrompter
15
 
16
+ def load_nrrd(file_path):
17
+ data, header = nrrd.read(file_path)
18
+
19
+ # If the data is 3D, take the middle slice
20
+ if len(data.shape) == 3:
21
+ middle_slice = data.shape[2] // 2
22
+ img = data[:, :, middle_slice]
23
  else:
24
+ img = data
25
+
26
+ # Normalize the image to 0-255 range
27
+ img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
28
 
29
  # Convert grayscale to 3-channel RGB by replicating channels
30
  if len(img.shape) == 2: # Grayscale image (height, width)
 
52
  sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
53
  dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
54
  multimask_output=False,
55
+ )
56
 
57
  low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
58
 
 
66
  medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
67
  return medsam_seg
68
 
 
69
  def visualize(image, mask, box):
70
  fig, ax = plt.subplots(1, 2, figsize=(10, 5))
71
  ax[0].imshow(image, cmap='gray')
 
74
  ax[1].imshow(mask, alpha=0.5, cmap="jet")
75
  plt.tight_layout()
76
 
 
77
  buf = io.BytesIO()
78
  fig.savefig(buf, format='png')
79
+ plt.close(fig)
80
  buf.seek(0)
81
  pil_img = Image.open(buf)
82
 
83
  return pil_img
84
 
85
+ def process_nrrd(nrrd_file, points):
 
86
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
87
 
88
+ # Load and preprocess NRRD file
89
+ image, H, W = load_nrrd(nrrd_file.name)
90
+
91
  if len(points) >= 6:
92
  x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
93
  else:
94
  raise ValueError("Insufficient data for bounding box coordinates.")
 
 
 
 
95
 
96
  image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
97
  image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
 
120
 
121
  # Set up Gradio interface
122
  iface = gr.Interface(
123
+ fn=process_nrrd,
124
  inputs=[
125
+ gr.File(label="NRRD File"),
126
+ gr.JSON(label="Bounding Box Coordinates")
127
  ],
128
  outputs=[
129
  gr.Image(type="pil", label="Processed Image")
130
  ],
131
+ title="ROI Selection with MEDSAM for NRRD Files",
132
+ description="Upload an NRRD file and provide bounding box coordinates for processing."
133
  )
134
 
135
  # Launch the interface
136
+ iface.launch()