Spaces:
Runtime error
Runtime error
import functools
import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SAM 2 model onto the chosen device.
# NOTE: from_pretrained forwards keyword arguments to sam2's model builder,
# whose `device` parameter defaults to "cuda" (per the sam2 sources). Without
# passing it explicitly, loading fails on CPU-only hosts.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large", device=device
)
def fig2img(fig):
    """Rasterise a Matplotlib figure and reopen the result as a PIL image.

    The figure is written with ``fig.savefig`` into an in-memory buffer, using
    Matplotlib's configured default output format. It is then decoded with
    ``Image.open``.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        The PIL image read back from the in-memory buffer.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so PIL reads from the first byte that savefig wrote.
    stream.seek(0)
    return Image.open(stream)
def plot_masks(image, masks):
    """Overlay segmentation masks on an image and return the rendering.

    Args:
        image: Background image; anything ``Axes.imshow`` accepts. The caller
            in this file passes a PIL RGB image.
        masks: Either an iterable of 2-D masks, or a single 2-D mask array.
            Masks are assumed to match the image's height/width (TODO confirm
            for all callers). Zero/False pixels are treated as background and
            left transparent.

    Returns:
        A PIL image of the composited figure, produced by ``fig2img``.
    """
    # A lone 2-D mask would otherwise be iterated row by row.
    if isinstance(masks, np.ndarray) and masks.ndim == 2:
        masks = masks[None]
    masks = list(masks)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        # Paint each mask with a constant value equal to its index, on a shared
        # vmin/vmax scale, so "jet" gives every mask its own colour.
        # Feeding a raw binary mask to imshow autoscales it to vmin == vmax,
        # which renders every mask in one identical colour.
        # Passing the cmap by name also avoids matplotlib.cm.get_cmap,
        # which was removed in Matplotlib 3.9.
        top = max(len(masks) - 1, 1)
        for index, mask in enumerate(masks):
            mask = np.asarray(mask)
            layer = np.ma.masked_where(
                mask == 0, np.full(mask.shape, index, dtype=float)
            )
            ax.imshow(layer, alpha=0.5, cmap="jet", vmin=0, vmax=top)
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Close this specific figure (not "the current" one) once rendered,
        # so repeated requests do not accumulate open figures.
        plt.close(fig)
@functools.lru_cache(maxsize=1)
def _get_mask_generator():
    """Build (once) an automatic mask generator that reuses the loaded SAM 2 model."""
    # SAM2ImagePredictor.predict only answers explicit point/box/mask prompts.
    # Calling it with an empty point list fails its prompt validation, so
    # whole-image ("everything") segmentation uses SAM2AutomaticMaskGenerator
    # on the same underlying model instead.
    return SAM2AutomaticMaskGenerator(predictor.model)


def segment_everything(input_image):
    """Segment every object in an uploaded image with SAM 2.

    Args:
        input_image: The image supplied by the Gradio ``numpy`` image input,
            or ``None`` when nothing was uploaded.

    Returns:
        A ``(result_image, status_message)`` tuple.

        - ``result_image`` is a PIL image with all masks overlaid, or ``None``
          when no image was given or an error occurred.
        - ``status_message`` is a human-readable status string for the UI.
          Errors are reported here (and logged) rather than raised, so the UI
          always gets a response.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."
    try:
        rgb_image = Image.fromarray(input_image).convert("RGB")
        # bfloat16 autocast is only enabled on CUDA. The original hard-coded
        # device_type="cuda" even when running on the CPU.
        with torch.inference_mode(), torch.autocast(
            device_type=device, dtype=torch.bfloat16, enabled=device == "cuda"
        ):
            annotations = _get_mask_generator().generate(np.asarray(rgb_image))
        # Draw large masks first so smaller ones remain visible on top.
        annotations = sorted(annotations, key=lambda ann: ann["area"], reverse=True)
        masks = [ann["segmentation"] for ann in annotations]
        result_image = plot_masks(rgb_image, masks)
        return result_image, f"Segmented everything in the image. Found {len(masks)} objects."
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, and show the
        # message to the user instead of crashing the request.
        logging.getLogger(__name__).exception("Segmentation failed")
        return None, f"An error occurred: {str(e)}"
# Build the Gradio UI. There is one numpy image input, and the two outputs
# (image, status text) line up with the tuple returned by segment_everything.
_upload_input = gr.Image(type="numpy", label="Upload an image")
_result_outputs = [
    gr.Image(type="pil", label="Segmented Image"),
    gr.Textbox(label="Status"),
]

iface = gr.Interface(
    fn=segment_everything,
    inputs=[_upload_input],
    outputs=_result_outputs,
    title="SAM 2 Everything Segmentation",
    description="Upload an image to segment all objects using SAM 2.",
)

# Start the Gradio app.
iface.launch()