dennistrujillo commited on
Commit
4a657e9
·
verified ·
1 Parent(s): 9840e47

Swapped csv bb input for dialogue box

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -45,13 +45,12 @@ def medsam_inference(medsam_model, img, box, H, W, target_size):
45
  return mask_resized
46
 
47
  # Function for visualizing images with masks
48
- def visualize(images, masks, box):
49
- fig, ax = plt.subplots(len(images), 2, figsize=(10, 5*len(images)))
50
- for i, (image, mask) in enumerate(zip(images, masks)):
51
- ax[i, 0].imshow(image, cmap='gray')
52
- ax[i, 0].add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], edgecolor="red", facecolor="none"))
53
- ax[i, 1].imshow(image, cmap='gray')
54
- ax[i, 1].imshow(mask, alpha=0.5, cmap="jet")
55
  plt.tight_layout()
56
  buf = io.BytesIO()
57
  plt.savefig(buf, format='png')
@@ -60,34 +59,32 @@ def visualize(images, masks, box):
60
  return buf
61
 
62
  # Main function for Gradio app
63
- def process_images(csv_file, dicom_file):
64
- bounding_boxes = load_bounding_boxes(csv_file)
65
  image, H, W = load_dicom_image(dicom_file)
66
 
67
  # Initialize MedSAM model
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
- medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth") # Ensure the correct path
70
  medsam_model = medsam_model.to(device)
71
  medsam_model.eval()
72
 
73
- masks = []
74
- boxes = []
75
- for index, row in bounding_boxes.iterrows():
76
- box = [row['x_min'], row['y_min'], row['x_max'], row['y_max']]
77
- mask = medsam_inference(medsam_model, image, box, H, W, H) # Assuming target size is the same as the image height
78
- masks.append(mask)
79
- boxes.append(box)
80
 
81
- visualizations = visualize([image] * len(masks), masks, boxes)
82
- return visualizations.getvalue()
83
 
84
  # Set up Gradio interface
85
  iface = gr.Interface(
86
  fn=process_images,
87
  inputs=[
88
- gr.File(label="CSV File"),
89
- gr.File(label="DICOM File")],
90
- outputs="plot"
 
 
 
 
91
  )
92
 
93
  iface.launch()
 
45
  return mask_resized
46
 
47
  # Function for visualizing images with masks
48
+ def visualize(image, mask, box):
49
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5))
50
+ ax[0].imshow(image, cmap='gray')
51
+ ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
52
+ ax[1].imshow(image, cmap='gray')
53
+ ax[1].imshow(mask, alpha=0.5, cmap="jet")
 
54
  plt.tight_layout()
55
  buf = io.BytesIO()
56
  plt.savefig(buf, format='png')
 
59
  return buf
60
 
61
  # Main function for Gradio app
62
+ def process_images(dicom_file, x_min, y_min, x_max, y_max):
 
63
  image, H, W = load_dicom_image(dicom_file)
64
 
65
  # Initialize MedSAM model
66
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
67
+ medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH) # Ensure the correct path
68
  medsam_model = medsam_model.to(device)
69
  medsam_model.eval()
70
 
71
+ box = [x_min, y_min, x_max, y_max]
72
+ mask = medsam_inference(medsam_model, image, box, H, W, H) # Assuming target size is the same as the image height
 
 
 
 
 
73
 
74
+ visualization = visualize(image, mask, box)
75
+ return visualization.getvalue() # Returning the byte stream
76
 
77
  # Set up Gradio interface
78
  iface = gr.Interface(
79
  fn=process_images,
80
  inputs=[
81
+ gr.inputs.File(label="DICOM File"),
82
+ gr.inputs.Number(label="X min"),
83
+ gr.inputs.Number(label="Y min"),
84
+ gr.inputs.Number(label="X max"),
85
+ gr.inputs.Number(label="Y max")
86
+ ],
87
+ outputs=gr.outputs.Image(type="plot")
88
  )
89
 
90
  iface.launch()