# CLIPSeg2 / app.py
# Last commit 6da5017 by aryswisnu: "Update app.py"
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import torch
import numpy as np
# Pretrained CLIPSeg checkpoint shared by the processor and the model.
CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Loaded once at import time; process_image reuses these module globals for
# every request.
processor = CLIPSegProcessor.from_pretrained(CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_CHECKPOINT)
def process_image(image, prompt, threhsold):
    """Return CLIPSeg's min-max normalized response map for one text prompt.

    Args:
        image: PIL image to segment.
        prompt: A single text prompt.
        threhsold: Unused here; thresholding happens in the caller. Kept
            (with its original spelling) so existing call sites still work.

    Returns:
        A float numpy array of shape ``(height, width)`` matching ``image``,
        rescaled so its minimum maps to 0 and its maximum to 1. The scaling is
        relative to this image, so the strongest region always reaches 1 even
        if the prompt matches poorly. A constant map yields all zeros instead
        of NaNs.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Logits for a single prompt may be (H, W) or (1, H, W) depending on the
    # transformers version; squeeze to a 2-D map either way.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()
    # Round-trip through an 8-bit grayscale PIL image to upsample the
    # low-resolution map to the input image's size.
    low_res = Image.fromarray(np.uint8(probs * 255))
    mask = np.array(low_res.resize(image.size), dtype=np.float64)

    # Min-max normalize; guard the flat-map case, where 0/0 would give NaN.
    mask_min = mask.min()
    mask_range = mask.max() - mask_min
    if mask_range == 0:
        return np.zeros_like(mask)
    return (mask - mask_min) / mask_range
def get_masks(prompts, img, threhsold):
prompts = prompts.split(",")
masks = []
for prompt in prompts:
mask = process_image(img, prompt, threhsold)
mask = mask > threhsold
fig, ax = plt.subplots()
ax.imshow(img)
ax.imshow(mask, alpha=0.5, cmap="jet")
ax.axis("off")
plt.tight_layout()
masks.append(mask)
return masks
def extract_image(img, pos_prompts, neg_prompts, threshold):
positive_masks = get_masks(pos_prompts, img, threshold)
negative_masks = get_masks(neg_prompts, img, threshold)
# combine masks into one masks, logic OR
pos_mask = np.any(np.stack(positive_masks), axis=0)
neg_mask = np.any(np.stack(negative_masks), axis=0)
final_mask = pos_mask & ~neg_mask
# extract the final image
final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
inverse_mask = np.invert(final_mask)
output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
output_image.paste(img, mask=final_mask)
return output_image, final_mask, inverse_mask
# Short title for the demo.
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
# Usage instructions rendered on the page; keep them in sync with the inputs.
description = (
    "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image "
    "segmentation. To use it, upload an image, describe what you want to "
    "identify (and, optionally, what to ignore) as comma-separated text, "
    "adjust the threshold if needed, and click the button below the inputs. "
    "Results will show up in a few seconds."
)
# Links to the paper and the Transformers documentation (HTML, rendered by
# gr.Markdown).
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2112.10003'>"
    "CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | "
    "<a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>"
    "HuggingFace docs</a></p>"
)
# Two-column layout: inputs and the trigger button on the left, the three
# outputs of extract_image (result, mask, inverse mask) on the right.
with gr.Blocks(title=title) as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            # A button's caption is its ``value``; ``label`` does not set the
            # caption and is not accepted by Gradio 4+.
            btn_process = gr.Button("Process")
        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
            inverse_mask = gr.Image(label="Inverse")
    btn_process.click(
        extract_image,
        inputs=[
            input_image,
            positive_prompts,
            negative_prompts,
            input_slider_T,
        ],
        outputs=[output_image, output_mask, inverse_mask],
        api_name="mask",
    )

demo.launch()