kushagra124 committed
Commit d1bffba • Parent(s): 297686d
adding app with CLIP image segmentation
Browse files
- app.py +93 -0
- images/image2.png +0 -0
- images/room.jpg +0 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,93 @@
import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# CLIPSeg checkpoint used for text-prompted image segmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_mask(image, image_mask, alpha=0.7):
    """Broadcast a single-channel mask to 3 channels and blend it over the image."""
    mask = np.zeros_like(image)
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image


def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    """Map a (min_row, min_col, max_row, max_col) box from model space to image space."""
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]


def detect_using_clip(image, prompts=[], threshold=0.4):
    """Run CLIPSeg once per prompt and return bounding boxes and binary masks."""
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only, no gradient computation needed
        outputs = model(**inputs)
    preds = outputs.logits
    if preds.dim() == 2:  # a single prompt comes back without a batch dimension
        preds = preds.unsqueeze(0)

    for i, prompt in enumerate(prompts):
        # Threshold the sigmoid probability map into a binary 0/255 mask
        predicted_image = torch.sigmoid(preds[i]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        # Extract connected regions and their bounding boxes
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
        # cv2.resize expects the target size as (width, height)
        predicted_images[prompt] = cv2.resize(predicted_image, (image.shape[1], image.shape[0]))
    return model_detections, predicted_images


def visualize_images(image, detections, predicted_images, prompt):
    """Overlay the predicted mask and draw labelled boxes for the selected prompt."""
    alpha = 0.7
    prompt = prompt.lower()
    image_copy = image.copy()

    if prompt not in detections:
        print("Selected category is not among the prompts.")
        return image_copy

    mask_image = create_mask(image=image_copy, image_mask=predicted_images[prompt])
    for bbox in detections[prompt]:
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
        cv2.putText(image_copy, str(prompt), (int(bbox[1]), int(bbox[0])), cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
    final_image = cv2.addWeighted(image_copy, alpha, mask_image, 1 - alpha, 0)
    return final_image


def shot(image, labels_text, selected_category):
    """Gradio callback: segment every prompted category, then visualize the selected one."""
    prompts = [p.strip() for p in labels_text.split(',')]

    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)

    category_image = visualize_images(
        image=image,
        detections=model_detections,
        predicted_images=predicted_images,
        prompt=selected_category,
    )
    return category_image


iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Add an image and a comma-separated list of categories to detect; the third input selects the category to visualize.",
    title="Zero-shot Image Segmentation with Text Prompts",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window", "plant"],
        ["images/image2.png", "banner, building, door, sign", "sign"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
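For reference, here is a minimal sketch of the same CLIPSeg call outside of Gradio, using one of the example images added in this commit; the prompt list and the 0.4 threshold are illustrative choices, not part of the app itself:

import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.open("images/room.jpg").convert("RGB")  # example image from this commit
prompts = ["plant", "bed"]                            # illustrative prompts

inputs = processor(text=prompts, images=[image] * len(prompts),
                   padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One 352x352 probability map per prompt
masks = torch.sigmoid(outputs.logits).cpu().numpy()
binary_masks = (masks > 0.4).astype(np.uint8) * 255   # same threshold as app.py
print(binary_masks.shape)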
images/image2.png
ADDED
images/room.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,11 @@
transformers
torch
sentencepiece
huggingface_hub
numpy
scikit-image
opencv-python
Pillow
requests
urllib3<2
git+https://github.com/facebookresearch/segment-anything.git
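To try the Space locally, something like the following should work (assuming Python 3 with pip; app.py and requirements.txt are the files added in this commit):

pip install -r requirements.txt
python app.py

Gradio then prints a local URL where the demo is served.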