import gradio as gr
import numpy as np
import torch
import cv2
from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def create_rgb_mask(mask):
    """Convert a single-channel 0/255 mask into a 3-channel mask with one random color."""
    color = tuple(np.random.choice(range(256), size=3))
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = color
    return gray_3_channel.astype(np.uint8)
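
# Illustrative sketch (this helper is not called by the app itself): it turns a
# binary 0/255 mask into an RGB layer colored wherever the mask is set, e.g.:
#   binary = (np.random.rand(64, 64) > 0.5).astype(np.uint8) * 255  # hypothetical mask
#   rgb = create_rgb_mask(binary)  # (64, 64, 3) uint8, random color where mask == 255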
def detect_using_clip(image, prompts=None, threshold=0.4):
    """Run CLIPSeg on one image with a list of text prompts; return one boolean mask per prompt."""
    prompts = prompts or []
    # The image is duplicated once per prompt so each prompt is scored against it.
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient tracking for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    predicted_masks = []
    for i in range(len(prompts)):
        # Sigmoid maps logits to [0, 1]; binarize at the threshold.
        predicted_image = torch.sigmoid(preds[i][0]).cpu().numpy()
        predicted_masks.append(np.where(predicted_image > threshold, 255, 0))
    return [mask.astype(bool) for mask in predicted_masks]
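
# Illustrative sketch (file path is hypothetical): the checkpoint predicts
# low-resolution masks (352x352), one per prompt, which is why the app resizes
# images to that size before overlaying:
#   img = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
#   masks = detect_using_clip(img, prompts=["chair", "plant"], threshold=0.4)
#   # masks is a list of two 352x352 boolean arrays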
def visualize_images(image, predicted_images, brightness=15, contrast=1.8):
    """Resize an image and adjust brightness/contrast (currently unused by the app)."""
    alpha = 0.7  # blend weight for the (commented-out) mask overlay below
    image_resize = cv2.resize(image, (352, 352))
    resize_image_copy = image_resize.copy()
    # for mask_image in predicted_images:
    #     resize_image_copy = cv2.addWeighted(resize_image_copy, alpha, mask_image, 1 - alpha, 10)
    return cv2.convertScaleAbs(resize_image_copy, alpha=contrast, beta=brightness)
def shot(alpha, beta, image, labels_text):
    """Gradio handler: segment `image` with the comma-separated labels and return the overlay."""
    # str.split(",") already covers the single-label (no comma) case.
    prompts = [label.strip() for label in labels_text.split(",")]
    mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(prompts)]
    # One distinct color per label from the tab20 colormap (drop the alpha channel).
    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
    # Resize the image to 352x352 so it matches the predicted mask resolution.
    resize_image = cv2.resize(image, (352, 352))
    predicted_images = detect_using_clip(image, prompts=prompts)
    category_image = overlay_masks(
        resize_image,
        np.stack(predicted_images, -1),
        labels=mask_labels,
        colors=cmap,
        alpha=alpha,
        beta=beta,
    )
    return category_image
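
# Illustrative sketch: shot() can also be called directly, outside Gradio
# (the path below is hypothetical; Gradio passes images as RGB numpy arrays):
#   img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
#   overlay = shot(0.4, 1.0, img, "chair, plant, books")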
iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Slider(0.1, 1, value=0.4, step=0.1, label="alpha", info="Choose between 0.1 and 1"),
        gr.Slider(0.1, 1, value=1, step=0.1, label="beta", info="Choose between 0.1 and 1"),
        "image",
        "text",
    ],
    outputs="image",
    description="Add an image and the labels to be detected, separated by commas (at least 2)",
    title="Zero-shot Image Segmentation with Prompt",
    examples=[
        [0.4, 1, "images/room.jpg", "chair, plant, flower pot, white cabinet, paintings, decorative plates, books"],
        [0.4, 1, "images/seats.jpg", "door, table, chairs"],
        [0.3, 0.8, "images/vegetables.jpg", "carrot, white radish, brinjal, basket, potato"],
        [0.5, 1, "images/room2.jpg", "door, plants, dog, coffee table, table lamp, carpet"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()