"""Gradio demo: segment every object in an uploaded image with SAM 2."""

import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor

logger = logging.getLogger(__name__)

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SAM 2 model on the selected device so CPU-only hosts do not fall back
# to the library's default device.
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large", device=device
)

# Prompt-free ("segment everything") inference is done by the automatic mask
# generator; it reuses the weights already loaded by ``predictor``.
mask_generator = SAM2AutomaticMaskGenerator(predictor.model)


def fig2img(fig):
    """Render a Matplotlib figure to PNG in memory and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the result does not depend on ``buf``.
    img.load()
    return img


def plot_masks(image, masks):
    """Overlay segmentation masks on an image and return the rendered picture.

    Args:
        image: The picture to draw underneath the masks (anything
            ``Axes.imshow`` accepts, e.g. a PIL image or an array).
        masks: Iterable of 2-D arrays; non-zero entries mark the pixels that
            belong to the corresponding object.

    Returns:
        A PIL image of the figure with every mask drawn semi-transparently,
        each in its own colour.
    """
    masks = list(masks)
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        # vmax must exceed vmin even when there is only a single mask.
        upper = max(len(masks), 2)
        for index, mask in enumerate(masks, start=1):
            mask = np.asarray(mask)
            # Fill the mask with its 1-based index so each object maps to a
            # different colour; pixels outside the mask stay transparent.
            overlay = np.ma.masked_where(
                mask == 0, np.full(mask.shape, index, dtype=float)
            )
            ax.imshow(overlay, alpha=0.5, cmap="jet", vmin=1, vmax=upper)
        ax.axis("off")
        return fig2img(fig)
    finally:
        # Close this specific figure (not "the current one") so concurrent
        # requests cannot close each other's figures and nothing leaks on error.
        plt.close(fig)


def segment_everything(input_image):
    """Gradio callback: segment all objects in the uploaded image.

    Args:
        input_image: Image array supplied by the ``gr.Image(type="numpy")``
            input, or ``None`` when nothing was uploaded.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the PIL overlay, or ``None``
        when no image was given or an error occurred; ``status`` is a
        human-readable message for the status textbox.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    try:
        rgb_image = Image.fromarray(input_image).convert("RGB")

        # bfloat16 autocast is only enabled on CUDA; on CPU the context is a
        # no-op so the model runs in its native precision.
        with torch.inference_mode(), torch.autocast(
            device_type=device, dtype=torch.bfloat16, enabled=(device == "cuda")
        ):
            annotations = mask_generator.generate(np.array(rgb_image))

        # Each annotation carries its binary mask under the "segmentation" key.
        masks = [annotation["segmentation"] for annotation in annotations]

        # Plot the results
        result_image = plot_masks(rgb_image, masks)

        return (
            result_image,
            f"Segmented everything in the image. Found {len(masks)} objects.",
        )
    except Exception as e:
        # UI boundary: keep the app alive, but record the traceback.
        logger.exception("Segmentation failed")
        return None, f"An error occurred: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=segment_everything,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="SAM 2 Everything Segmentation",
    description="Upload an image to segment all objects using SAM 2.",
)

# Launch the interface
iface.launch()