Spaces:
Runtime error
Runtime error
from typing import List | |
import gradio as gr | |
import numpy as np | |
import supervision as sv | |
import torch | |
from PIL import Image | |
from transformers import pipeline, CLIPProcessor, CLIPModel | |
# Header text rendered at the top of the Gradio page.
MARKDOWN = """
# Segment Anything Model + MetaCLIP
This is a demo of open-vocabulary image segmentation using the
[Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
"""
# (image URL, prompt) pairs shown under the demo as clickable examples.
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog"],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket"],
]

# Use the GPU when available; both models below are placed on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Class-agnostic mask proposals from SAM (ViT-L checkpoint).
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)

# MetaCLIP model/processor used to score each masked crop against a text prompt.
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

# Every kept mask is drawn in pure red. `sv.Color.red()` is deprecated in newer
# supervision releases (replaced by `sv.Color.RED`); constructing the colour
# from its RGB components works on both old and new versions.
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color(r=255, g=0, b=0),
    color_lookup=sv.ColorLookup.INDEX,
)
def run_sam(image_rgb_pil: Image.Image, points_per_batch: int = 32) -> sv.Detections:
    """Generate class-agnostic mask proposals for an image with SAM.

    Args:
        image_rgb_pil: Input image.
        points_per_batch: Number of prompt points SAM_GENERATOR processes per
            forward pass (higher is faster but uses more memory).

    Returns:
        Detections holding one mask per proposal, with boxes derived from the
        masks via ``sv.mask_to_xyxy``. When SAM returns no masks, an empty but
        well-formed Detections is returned instead of failing.
    """
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=points_per_batch)
    masks = np.array(outputs["masks"])
    if len(masks) == 0:
        # np.array([]) would have shape (0,), which is not a valid
        # (N, H, W) mask stack; build explicit empty arrays instead.
        # NOTE(review): assumes SAM masks are at the input image's resolution.
        width, height = image_rgb_pil.size
        return sv.Detections(
            xyxy=np.empty((0, 4), dtype=np.float32),
            mask=np.empty((0, height, width), dtype=bool),
        )
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
    """Score an image against candidate captions with MetaCLIP.

    Args:
        image_rgb_pil: Image to classify.
        text: Candidate captions; the output has one column per caption.

    Returns:
        NumPy array of softmax probabilities over ``text``, computed along the
        caption axis of ``logits_per_image`` (one row per image).
    """
    inputs = CLIP_PROCESSOR(
        text=text,
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True,
    ).to(DEVICE)
    # Pure inference: disable autograd so no computation graph is built
    # (it is never used and only costs memory/time per call).
    with torch.no_grad():
        outputs = CLIP_MODEL(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    return probs.cpu().numpy()
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Keep the pixels selected by ``mask`` and paint everything else gray.

    Args:
        image: Image array whose last axis holds the channels (presumably
            (H, W, 3) RGB here, since the fill colour has three components).
        mask: Array over the same leading (row, column) axes as ``image``;
            truthy entries mark pixels to keep.
        gray_value: Value written to each of the three channels of the
            background pixels, as ``uint8``.

    Returns:
        A new array: ``image`` where ``mask`` is truthy, gray elsewhere.
    """
    background = np.array([gray_value] * 3, dtype=np.uint8)
    # Give the 2-D mask a trailing axis so it broadcasts across channels.
    keep = np.expand_dims(mask, axis=-1)
    return np.where(keep, image, background)
def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
    """Draw the detections' masks onto an image using MASK_ANNOTATOR.

    The image is handed to the annotator in BGR channel order and converted
    back to RGB for the returned PIL image.

    Args:
        image_rgb_pil: Background image; converted to RGB if it is in another
            PIL mode so the channel flip below is always valid.
        detections: Detections whose masks are drawn.

    Returns:
        A new RGB PIL image with the masks overlaid.
    """
    img_rgb_numpy = np.array(image_rgb_pil.convert("RGB"))
    # `[:, :, ::-1]` is a negative-stride view; cv2-based annotators may fail
    # to write into such views, so give them a contiguous copy.
    img_bgr_numpy = np.ascontiguousarray(img_rgb_numpy[:, :, ::-1])
    annotated_bgr_image = MASK_ANNOTATOR.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(np.ascontiguousarray(annotated_bgr_image[:, :, ::-1]))
def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str
) -> sv.Detections:
    """Keep only the detections that MetaCLIP matches to ``prompt``.

    For each detection, the image is cropped to its box and the pixels outside
    its mask are grayed out. The crop is then scored against the captions
    "a picture of <prompt>" and "a picture of background". The detection is
    kept when the prompt caption wins.

    Args:
        image_rgb_pil: Source image the detections were computed on.
        detections: Candidate detections; masks are required.
        prompt: Free-text description of the target object.

    Returns:
        The subset of ``detections`` classified as ``prompt`` (possibly empty).

    Raises:
        ValueError: If ``detections`` is non-empty but carries no masks.
    """
    if len(detections) == 0:
        return detections
    if detections.mask is None:
        raise ValueError("filter_detections requires detections with masks")

    img_rgb_numpy = np.array(image_rgb_pil)
    # Index 0 is the prompt caption, index 1 the background caption.
    text = [f"a picture of {prompt}", "a picture of background"]
    keep = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        probs = run_clip(image_rgb_pil=Image.fromarray(masked_crop), text=text)
        class_index = int(np.argmax(probs))
        keep.append(class_index == 0)
    # Explicit bool dtype: an empty list would otherwise become a float64
    # array, which is not a valid index for Detections.
    return detections[np.array(keep, dtype=bool)]
def inference(
    image_rgb_pil: Image.Image,
    prompt: str,
    min_area_ratio: float = 0.01,
) -> List[Image.Image]:
    """Segment the objects in an image that match a text prompt.

    Pipeline:
        1. SAM proposes masks.
        2. Masks covering at most ``min_area_ratio`` of the image are dropped.
        3. MetaCLIP keeps the masks that match ``prompt``.

    Args:
        image_rgb_pil: Uploaded image (``None`` if the user submitted nothing).
        prompt: Text description of the object to segment.
        min_area_ratio: Minimum mask area as a fraction of the image area;
            masks must be strictly larger than this to be considered.

    Returns:
        Two images for the gallery: the masks overlaid on the input image, and
        the same masks drawn on a black canvas of the same size.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_rgb_pil is None:
        raise gr.Error("Please upload an image.")
    width, height = image_rgb_pil.size
    image_area = width * height
    detections = run_sam(image_rgb_pil)
    detections = detections[detections.area / image_area > min_area_ratio]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt)
    return [
        annotate(image_rgb_pil=image_rgb_pil, detections=detections),
        annotate(image_rgb_pil=Image.new("RGB", (width, height), "black"), detections=detections),
    ]
# Gradio UI: image + prompt inputs on the left, result gallery on the right,
# clickable examples below.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # type='pil' makes Gradio pass a PIL image to `inference`.
            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
            prompt_text = gr.Textbox(label="Prompt", value="dog")
            submit_button = gr.Button("Submit")
        # Receives the list of images returned by `inference`.
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        # With cache_examples=True, Gradio runs `inference` on each example
        # ahead of time (presumably at startup) and serves the cached outputs.
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image, prompt_text],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )

    # Event listeners are registered inside the Blocks context.
    submit_button.click(
        inference,
        inputs=[input_image, prompt_text],
        outputs=gallery)

# show_error=True surfaces Python exceptions in the browser UI.
demo.launch(debug=False, show_error=True)