from typing import List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel

# ************
# Global variables
MARKDOWN = """
# SAM
"""
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]
MIN_AREA_THRESHOLD = 0.01

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# SAM mask-generation pipeline used to segment everything in the image.
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE
)

# CLIP model and processor used to score each SAM mask against the text prompt.
# NOTE: the checkpoint name is an assumption; the original snippet imported CLIP
# but never instantiated it.
CLIP_MODEL = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX
)
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.white(),
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1
)


# ************
# Helper functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    # Generate masks for the whole image and wrap them as supervision detections.
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


def run_clip(subject: Image.Image, prompts: List[str]) -> np.ndarray:
    # Score a cropped mask against the candidate prompts with CLIP.
    # Assumed helper: the original code imported CLIP but the scoring step was missing.
    inputs = CLIP_PROCESSOR(
        text=prompts,
        images=subject,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = CLIP_MODEL(**inputs)
    return outputs.logits_per_image.softmax(dim=1).cpu().numpy()


def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    # Keep masked pixels and replace the background with a flat gray.
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)


def annotate(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    annotator: sv.MaskAnnotator
) -> Image.Image:
    # Assumed helper: the original called annotate() but never defined it.
    img_rgb_numpy = np.array(image_rgb_pil)
    annotated_image = annotator.annotate(scene=img_rgb_numpy, detections=detections)
    return Image.fromarray(annotated_image)


def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str,
    confidence: float
) -> sv.Detections:
    img_rgb_numpy = np.array(image_rgb_pil)
    prompts = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        # Keep the mask only if CLIP assigns the prompt a probability above the
        # confidence threshold (assumed criterion; the original loop body never
        # appended anything to filtering_mask).
        probs = run_clip(subject=Image.fromarray(masked_crop), prompts=prompts)
        filtering_mask.append(probs[0][0] > confidence)
    filtering_mask = np.array(filtering_mask, dtype=bool)
    return detections[filtering_mask]


def inference(image_rgb_pil: Image.Image, prompt: str, confidence: float) -> List[Image.Image]:
    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt,
        confidence=confidence
    )

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR)
    ]


# ************
# GRADIO CONSTRUCTION
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                image_mode='RGB',
                type='pil',
                height=500
            )
            # Prompt and confidence inputs were referenced but never created in
            # the original; these components are assumed.
            prompt_text = gr.Textbox(label="Prompt", value="dog")
            confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5
            )
            submit_button = gr.Button("Try it!")
        gallery = gr.Gallery(
            label="Result",
            object_fit="scale-down",
            preview=True
        )
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image, prompt_text, confidence_slider],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )

    submit_button.click(
        inference,
        inputs=[input_image, prompt_text, confidence_slider],
        outputs=gallery
    )

demo.launch(debug=True, show_error=True)