Spaces:
Sleeping
Sleeping
from typing import List | |
import os | |
import cv2 | |
import supervision as sv | |
import numpy as np | |
import gradio as gr | |
import torch | |
from transformers import pipeline | |
from PIL import Image | |
# Adapter that exposes a mask-generation pipeline through a ``generate`` method.
class SamAutomaticMaskGenerator:
    """Wrap a mask-generation pipeline callable behind ``generate``."""

    def __init__(self, sam_pipeline):
        """Store the pipeline callable that ``generate`` will invoke."""
        self.sam_pipeline = sam_pipeline

    def generate(self, image_rgb):
        """Run the wrapped pipeline on an image array and return its masks.

        Args:
            image_rgb: NumPy image array; it is converted to a PIL image
                before being handed to the pipeline.

        Returns:
            A ``uint8`` NumPy array built from the ``"masks"`` entry of the
            pipeline output (presumably one 2-D mask per segment, stacked
            along the first axis -- confirm against the pipeline in use).
        """
        # The pipeline is called with a PIL image, so convert the array first.
        pipeline_input = Image.fromarray(image_rgb)
        result = self.sam_pipeline(pipeline_input, points_per_batch=32)
        return np.array(result["masks"], dtype=np.uint8)
# SAM model configuration and demo assets.

# Example images offered in the UI (one single-element row per example).
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
]

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face mask-generation pipeline backed by the SAM ViT-Large checkpoint.
sam_pipeline = pipeline(
    "mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)

# Shared generator instance used by the request handler.
mask_generator = SamAutomaticMaskGenerator(sam_pipeline)
# Segment an uploaded image and overlay the resulting masks on it.
def process_image(image_pil):
    """Segment ``image_pil`` and return it together with an annotated copy.

    Args:
        image_pil: PIL image supplied by the Gradio input component, or
            ``None`` when the user has not provided an image.

    Returns:
        A ``(original, annotated)`` tuple of PIL images. ``(None, None)`` is
        returned when no image was supplied, and the original image is
        returned in both positions when no masks were produced.
    """
    # Clicking the button with an empty input delivers None; leave the
    # outputs empty instead of failing inside NumPy/OpenCV.
    if image_pil is None:
        return None, None

    # Force three channels so the colour conversions below are always valid
    # (e.g. for RGBA or greyscale uploads).
    image_rgb = np.array(image_pil.convert("RGB"))

    masks = np.asarray(mask_generator.generate(image_rgb))
    if masks.ndim != 3 or masks.shape[0] == 0:
        # Nothing was segmented, so there is nothing to draw.
        return image_pil, image_pil

    # ``sv.Detections.from_sam`` expects the list of dicts emitted by the
    # segment_anything generator ("segmentation", "bbox", "area"), whereas
    # ``mask_generator.generate`` returns a plain mask array. Build the
    # detections directly instead. The masks must be boolean because the
    # annotator uses them to index the image.
    masks = masks.astype(bool)
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks), mask=masks)

    # The annotator writes colours in BGR order, so annotate a BGR copy and
    # convert back to RGB for display.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_bgr = mask_annotator.annotate(scene=image_bgr, detections=detections)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    return image_pil, Image.fromarray(annotated_rgb)
# Build the Gradio interface: an upload column with a trigger button, a result
# column showing the original and segmented images, and clickable examples.
with gr.Blocks() as demo:
    # User-facing heading (accented characters restored from mis-encoded text).
    gr.Markdown("# SAM - Segmentación de Imágenes")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Cargar Imagen")
            submit_button = gr.Button("Segmentar")
        with gr.Column():
            original_image = gr.Image(type="pil", label="Imagen Original")
            segmented_image = gr.Image(type="pil", label="Imagen Segmentada")

    # Run the segmentation when the button is pressed.
    submit_button.click(
        process_image,
        inputs=input_image,
        outputs=[original_image, segmented_image],
    )

    with gr.Row():
        # Examples are not pre-computed; clicking one runs ``process_image``.
        gr.Examples(
            examples=EXAMPLES,
            fn=process_image,
            inputs=[input_image],
            outputs=[original_image, segmented_image],
            cache_examples=False,
            run_on_click=True,
        )

demo.launch(debug=True)