dennistrujillo committed on
Commit
a2aae29
·
verified ·
1 Parent(s): de86f89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -60,6 +60,7 @@ def visualize(image, mask, box):
60
  buf.seek(0)
61
  return buf
62
 
 
63
  # Main function for Gradio app
64
  def process_images(file, x_min, y_min, x_max, y_max):
65
  image, H, W = load_image(file)
@@ -67,24 +68,20 @@ def process_images(file, x_min, y_min, x_max, y_max):
67
  # Check if CUDA is available, and set the device accordingly
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
 
70
- # Load the model with appropriate device mapping
71
- model_checkpoint_path = "medsam_vit_b.pth"
72
- if device == 'cpu':
73
- checkpoint = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))
74
- else:
75
- checkpoint = torch.load(model_checkpoint_path)
76
 
77
- # Create model instance
78
- medsam_model = sam_model_registry['vit_b'] # Create model instance without checkpoint
79
 
80
- # Load the state dictionary into the model
81
- medsam_model.load_state_dict(checkpoint) # Load the saved weights
 
82
 
83
- medsam_model = medsam_model.to(device)
84
  medsam_model.eval()
85
 
86
  box = [x_min, y_min, x_max, y_max]
87
- mask = medsam_inference(medsam_model, image, box, H, W, H, device) # Assuming target size is the same as the image height
88
 
89
  visualization = visualize(image, mask, box)
90
  return visualization.getvalue() # Returning the byte stream
 
60
  buf.seek(0)
61
  return buf
62
 
63
# Main function for Gradio app
def process_images(file, x_min, y_min, x_max, y_max):
    """Run MedSAM segmentation on an uploaded image within a bounding box.

    Parameters
    ----------
    file : uploaded image file, passed through to ``load_image``.
    x_min, y_min, x_max, y_max : bounding-box coordinates (pixels).

    Returns
    -------
    bytes
        PNG byte stream produced by ``visualize`` (via ``getvalue()``),
        suitable for display in the Gradio UI.
    """
    image, H, W = load_image(file)

    # Check if CUDA is available, and set the device accordingly
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Define the checkpoint path
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint

    # Create the model instance and load the checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)

    # FIX: previous code only mapped the model to CPU and never moved it onto
    # the GPU when CUDA was available, so inference received device='cuda'
    # while the weights stayed on CPU. Move to the selected device always.
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    box = [x_min, y_min, x_max, y_max]
    # NOTE(review): the third size argument is H (image height), assuming a
    # square target size — confirm against medsam_inference's signature.
    mask = medsam_inference(medsam_model, image, box, H, W, H, device)  # Pass device to the inference function

    visualization = visualize(image, mask, box)
    return visualization.getvalue()  # Returning the byte stream