from typing import List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel
# ************
# Global variables
MARKDOWN = """
# SAM
"""
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]
MIN_AREA_THRESHOLD = 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE
)
# NOTE: sv.Color.red() / sv.Color.white() are deprecated in newer supervision
# releases in favour of sv.Color.RED / sv.Color.WHITE.
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX
)
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.white(),
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1
)
# ************
# Worker functions
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    # The mask-generation pipeline returns a dict with a "masks" list
    # (one boolean array per segment).
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    masks = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    # Keep the image where the mask is True and fill the rest with flat gray,
    # so the classifier only sees the segmented object.
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)
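# `filter_detections` needs a per-crop score against the text prompt, but no CLIP
# inference appears in the listing. This helper is a minimal sketch: it contrasts
# the prompt with a generic "background" caption and returns the softmax
# probability of the prompt, which is then compared with the confidence slider.
# The actual Space may score crops differently.
def run_clip(image_rgb_pil: Image.Image, text: str) -> float:
    inputs = CLIP_PROCESSOR(
        text=[text, "background"],
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = CLIP_MODEL(**inputs)
    # logits_per_image has shape (1, 2); index 0 is the prompt.
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs[0][0].item()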
def filter_detections(
        image_rgb_pil: Image.Image,
        detections: sv.Detections,
        prompt: str,
        confidence: float
) -> sv.Detections:
    # Keep only the SAM masks whose gray-padded crop matches the prompt with at
    # least the requested confidence. The prompt/confidence parameters and the CLIP
    # scoring step were missing from the listing; they follow from the Gradio
    # inputs and EXAMPLES defined above.
    img_rgb_numpy = np.array(image_rgb_pil)
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        probability = run_clip(image_rgb_pil=Image.fromarray(masked_crop), text=prompt)
        filtering_mask.append(probability > confidence)
    filtering_mask = np.array(filtering_mask, dtype=bool)
    return detections[filtering_mask]
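# `inference` calls `annotate`, but the function is missing from the listing.
# This is a minimal sketch that overlays the detections on the image with the
# given supervision mask annotator.
def annotate(
        image_rgb_pil: Image.Image,
        detections: sv.Detections,
        annotator: sv.MaskAnnotator
) -> Image.Image:
    img_rgb_numpy = np.array(image_rgb_pil)
    annotated_image = annotator.annotate(scene=img_rgb_numpy, detections=detections)
    return Image.fromarray(annotated_image)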
def inference(image_rgb_pil: Image.Image, prompt: str, confidence: float) -> List[Image.Image]:
    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    # Discard masks covering less than MIN_AREA_THRESHOLD of the image.
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt,
        confidence=confidence
    )

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR)
    ]
# ************
# Gradio UI construction
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                image_mode='RGB',
                type='pil',
                height=500
            )
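            # The prompt textbox and confidence slider are referenced below but
            # were missing from the listing; labels and defaults are assumptions.
            prompt_text = gr.Textbox(
                label="Prompt",
                value="dog"
            )
            confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5
            )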
            submit_button = gr.Button("Try it!")
        gallery = gr.Gallery(
            label="Result",
            object_fit="scale-down",
            preview=True
        )
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[
                input_image,
                prompt_text,
                confidence_slider
            ],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )
    submit_button.click(
        inference,
        inputs=[
            input_image,
            prompt_text,
            confidence_slider
        ],
        outputs=gallery
    )

demo.launch(debug=True, show_error=True)