dennistrujillo commited on
Commit
42aa5e0
·
verified ·
1 Parent(s): f7c280e

Fixed Model Loading Issue for CPU-Only Environments

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -64,9 +64,9 @@ def visualize(image, mask, box):
64
  def process_images(file, x_min, y_min, x_max, y_max):
65
  image, H, W = load_image(file)
66
 
67
- # Initialize MedSAM model
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
-
70
  # Load the model with appropriate device mapping
71
  model_checkpoint_path = "medsam_vit_b.pth"
72
  if device == 'cpu':
@@ -74,7 +74,12 @@ def process_images(file, x_min, y_min, x_max, y_max):
74
  else:
75
  checkpoint = torch.load(model_checkpoint_path)
76
 
77
- medsam_model = sam_model_registry['vit_b'](checkpoint=checkpoint)
 
 
 
 
 
78
  medsam_model = medsam_model.to(device)
79
  medsam_model.eval()
80
 
 
64
  def process_images(file, x_min, y_min, x_max, y_max):
65
  image, H, W = load_image(file)
66
 
67
+ # Check if CUDA is available, and set the device accordingly
68
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+
70
  # Load the model with appropriate device mapping
71
  model_checkpoint_path = "medsam_vit_b.pth"
72
  if device == 'cpu':
 
74
  else:
75
  checkpoint = torch.load(model_checkpoint_path)
76
 
77
+ # Create model instance
78
+ medsam_model = sam_model_registry['vit_b']() # Create model instance without checkpoint
79
+
80
+ # Load the state dictionary into the model
81
+ medsam_model.load_state_dict(checkpoint) # Load the saved weights
82
+
83
  medsam_model = medsam_model.to(device)
84
  medsam_model.eval()
85