dennistrujillo committed on
Commit
201e3ec
·
verified ·
1 Parent(s): 1c3d5f7

Integrated MedSAM Model with Full Pipeline for Image Segmentation

Browse files
Files changed (1) hide show
  1. app.py +47 -38
app.py CHANGED
@@ -24,34 +24,35 @@ def load_image(file_path):
24
  H, W = img.shape[:2]
25
  return img, H, W
26
 
27
- # MedSAM inference function
28
- def medsam_inference(medsam_model, img, box, H, W, target_size, device):
29
- # Assuming the model expects 1024x1024 input
30
- expected_model_input_size = 1024
31
-
32
- # Resize image to expected model input size
33
- img_resized = transform.resize(img, (expected_model_input_size, expected_model_input_size), anti_aliasing=True)
34
-
35
- # Ensure the image is in the correct shape (H, W, C)
36
- if len(img_resized.shape) == 3 and img_resized.shape[2] == 3:
37
- # Convert to PyTorch tensor and add batch dimension
38
- img_tensor = torch.from_numpy(img_resized.transpose((2, 0, 1))).float().unsqueeze(0).to(device)
39
- else:
40
- raise ValueError("Image must be a 3-channel (RGB) image")
41
-
42
- box_resized = np.array(box) * (target_size / np.array([W, H, W, H]))
43
-
44
- # Model expects box in format (x0, y0, x1, y1)
45
- box_tensor = torch.tensor(box_resized, dtype=torch.float32).unsqueeze(0).to(device) # Add batch dimension
46
-
47
- # MedSAM inference
48
- img_embed = medsam_model.image_encoder(img_tensor)
49
- mask = medsam_model.predict(img_embed, box_tensor)
50
-
51
- # Post-process mask: resize back to original size
52
- mask_resized = transform.resize(mask[0].cpu().numpy(), (H, W))
53
-
54
- return mask_resized
 
55
 
56
  # Function for visualizing images with masks
57
  def visualize(image, mask, box):
@@ -67,11 +68,13 @@ def visualize(image, mask, box):
67
  buf.seek(0)
68
  return buf
69
 
70
- # Main function for Gradio app
71
  # Main function for Gradio app
72
  def process_images(file, x_min, y_min, x_max, y_max):
 
73
  image, H, W = load_image(file)
74
-
 
 
75
  # Check if CUDA is available, and set the device accordingly
76
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
77
 
@@ -80,18 +83,24 @@ def process_images(file, x_min, y_min, x_max, y_max):
80
 
81
  # Create the model instance and load the checkpoint
82
  medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
 
 
83
 
84
- # If running on CPU, map the model to CPU
85
- if device == 'cpu':
86
- medsam_model = medsam_model.to(torch.device('cpu'))
87
 
88
- medsam_model.eval()
 
 
 
 
 
 
 
89
 
90
- box = [x_min, y_min, x_max, y_max]
91
- mask = medsam_inference(medsam_model, image, box, H, W, H, device) # Pass device to the inference function
92
 
93
- visualization = visualize(image, mask, box)
94
- return visualization.getvalue() # Returning the byte stream
95
 
96
  # Set up Gradio interface
97
  iface = gr.Interface(
 
24
  H, W = img.shape[:2]
25
  return img, H, W
26
 
27
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Decode a binary segmentation mask from a precomputed image embedding.

    Runs the model's prompt encoder on the box prompt, feeds the result and
    the image embedding to the mask decoder, resizes the sigmoid
    probabilities to ``(H, W)`` and thresholds them at 0.5.

    Args:
        medsam_model: Object exposing ``prompt_encoder`` (with
            ``get_dense_pe()``) and ``mask_decoder``, called with the
            keyword arguments used below.
        img_embed: Output of the image encoder; the box tensor is created on
            the same device.
        box_1024: Box prompt ``(x0, y0, x1, y1)``, either a single flat box
            of shape ``(4,)`` or a batch of shape ``(B, 4)`` (a ``(B, 1, 4)``
            input is passed through unchanged). The caller in this file
            rescales the coordinates to the 1024x1024 model input.
        H: Height of the returned mask.
        W: Width of the returned mask.

    Returns:
        ``numpy.uint8`` array holding 1 where the predicted probability
        exceeds 0.5 and 0 elsewhere; singleton dimensions are squeezed away,
        so a single box yields shape ``(H, W)``.
    """
    # Inference only: without no_grad the decoder output would track
    # gradients through the model weights and ``.numpy()`` below would raise.
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.ndim == 1:
        # A flat (4,) box is one prompt, not a batch of four; give it a
        # batch axis so the prompt encoder does not read 4 as the batch size.
        box_torch = box_torch[None, :]  # (1, 4)
    if box_torch.ndim == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # probabilities at decoder resolution

    # Upsample the probabilities to the requested output size.
    full_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)
    full_res_pred = full_res_pred.squeeze().cpu().numpy()  # (H, W) for a single box
    medsam_seg = (full_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
56
 
57
  # Function for visualizing images with masks
58
  def visualize(image, mask, box):
 
68
  buf.seek(0)
69
  return buf
70
 
 
71
  # Main function for Gradio app
72
  def process_images(file, x_min, y_min, x_max, y_max):
73
+ # Load and preprocess image
74
  image, H, W = load_image(file)
75
+ image_resized = transform.resize(image, (1024, 1024), anti_aliasing=True)
76
+ image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
77
+
78
  # Check if CUDA is available, and set the device accordingly
79
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
80
 
 
83
 
84
  # Create the model instance and load the checkpoint
85
  medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
86
+ medsam_model = medsam_model.to(device)
87
+ medsam_model.eval()
88
 
89
+ # Convert image to tensor and move to the correct device
90
+ image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
 
91
 
92
+ # Generate image embedding
93
+ with torch.no_grad():
94
+ img_embed = medsam_model.image_encoder(image_tensor)
95
+
96
+ # Calculate resized box coordinates and perform inference
97
+ scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
98
+ box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
99
+ mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)
100
 
101
+ visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
102
+ return visualization.getvalue()
103
 
 
 
104
 
105
  # Set up Gradio interface
106
  iface = gr.Interface(