Spaces:
Sleeping
Sleeping
dennistrujillo
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -60,6 +60,7 @@ def visualize(image, mask, box):
|
|
60 |
buf.seek(0)
|
61 |
return buf
|
62 |
|
|
|
63 |
# Main function for Gradio app
|
64 |
def process_images(file, x_min, y_min, x_max, y_max):
|
65 |
image, H, W = load_image(file)
|
@@ -67,24 +68,20 @@ def process_images(file, x_min, y_min, x_max, y_max):
|
|
67 |
# Check if CUDA is available, and set the device accordingly
|
68 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
69 |
|
70 |
-
#
|
71 |
-
model_checkpoint_path = "medsam_vit_b.pth"
|
72 |
-
if device == 'cpu':
|
73 |
-
checkpoint = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))
|
74 |
-
else:
|
75 |
-
checkpoint = torch.load(model_checkpoint_path)
|
76 |
|
77 |
-
# Create model instance
|
78 |
-
medsam_model = sam_model_registry['vit_b']
|
79 |
|
80 |
-
#
|
81 |
-
|
|
|
82 |
|
83 |
-
medsam_model = medsam_model.to(device)
|
84 |
medsam_model.eval()
|
85 |
|
86 |
box = [x_min, y_min, x_max, y_max]
|
87 |
-
mask = medsam_inference(medsam_model, image, box, H, W, H, device)
|
88 |
|
89 |
visualization = visualize(image, mask, box)
|
90 |
return visualization.getvalue() # Returning the byte stream
|
|
|
60 |
buf.seek(0)
|
61 |
return buf
|
62 |
|
63 |
+
# Main function for Gradio app
|
64 |
# Main function for Gradio app
|
65 |
def process_images(file, x_min, y_min, x_max, y_max):
|
66 |
image, H, W = load_image(file)
|
|
|
68 |
# Check if CUDA is available, and set the device accordingly
|
69 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
70 |
|
71 |
+
# Define the checkpoint path
|
72 |
+
model_checkpoint_path = "medsam_vit_b.pth" # Replace with the correct path to your checkpoint
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
# Create the model instance and load the checkpoint
|
75 |
+
medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
|
76 |
|
77 |
+
# If running on CPU, map the model to CPU
|
78 |
+
if device == 'cpu':
|
79 |
+
medsam_model = medsam_model.to(torch.device('cpu'))
|
80 |
|
|
|
81 |
medsam_model.eval()
|
82 |
|
83 |
box = [x_min, y_min, x_max, y_max]
|
84 |
+
mask = medsam_inference(medsam_model, image, box, H, W, H, device) # Pass device to the inference function
|
85 |
|
86 |
visualization = visualize(image, mask, box)
|
87 |
return visualization.getvalue() # Returning the byte stream
|