import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# CLIPSeg predicts one low-resolution (352x352) segmentation map per text prompt.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def create_mask(image, image_mask, alpha=0.7):
    """Blend a single-channel mask (same height and width as `image`) over `image`."""
    mask = np.zeros_like(image)
    # Copy the 2-D mask into each colour channel so it matches the image shape.
    for i in range(3):
        mask[:, :, i] = image_mask
    # cv2.addWeighted requires both inputs to have the same size and dtype.
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image

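def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    """Map a regionprops bbox (min_row, min_col, max_row, max_col) from the model's
    square output grid back to original image pixels, returned as [y1, x1, y2, x2]."""
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]   # rows scale with image height
    x1, x2 = bbox[1::2] * orig_image_shape[1]  # columns scale with image width
    return [int(y1), int(x1), int(y2), int(x2)]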

def detect_using_clip(image, prompts, threshold=0.4):
    """Run CLIPSeg once per prompt; return per-prompt boxes and binary masks."""
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only: skip gradient bookkeeping
        outputs = model(**inputs)
    # Expected logits shape: (num_prompts, 352, 352). Depending on the
    # transformers version, a single prompt may come back as (352, 352),
    # so restore the batch axis in that case.
    logits = outputs.logits
    if logits.dim() == 2:
        logits = logits.unsqueeze(0)

    H, W = image.shape[:2]
    for i, prompt in enumerate(prompts):
        probs = torch.sigmoid(logits[i]).cpu().numpy()
        # Binarise the heat map; uint8 so OpenCV can resize it.
        predicted_image = np.where(probs > threshold, 255, 0).astype(np.uint8)
        # Label connected foreground regions and take their bounding boxes.
        props = regionprops(label(predicted_image))
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=(H, W), model_shape=predicted_image.shape[0])
            for prop in props
        ]
        # cv2.resize expects dsize as (width, height), not (height, width).
        predicted_images[prompt] = cv2.resize(predicted_image, (W, H), interpolation=cv2.INTER_NEAREST)
    return model_detections, predicted_images

def visualize_images(image, detections, predicted_images, prompt, alpha=0.7):
    """Draw the boxes and mask overlay for one prompt onto a copy of `image`."""
    prompt = prompt.strip().lower()
    image_copy = image.copy()
    # Check the category before indexing into the per-prompt dictionaries.
    if prompt not in detections:
        print(f"Selected category {prompt!r} is not among the prompts.")
        return image_copy
    mask_image = create_mask(image=image_copy, image_mask=predicted_images[prompt])
    for y1, x1, y2, x2 in detections[prompt]:
        # Boxes are stored as [y1, x1, y2, x2]; OpenCV wants (x, y) points.
        cv2.rectangle(image_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.putText(image_copy, prompt, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0))
    final_image = cv2.addWeighted(image_copy, alpha, mask_image, 1 - alpha, 0)
    return final_image

def shot(image, labels_text, selected_category):
    # Split the comma-separated labels and drop empty entries such as "a,,b".
    prompts = [p.strip() for p in labels_text.split(",") if p.strip()]

    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)

    category_image = visualize_images(image=image, detections=model_detections,
                                      predicted_images=predicted_images, prompt=selected_category)
    return category_image

iface = gr.Interface(fn=shot,
                    inputs = ["image","text","text"],
                    outputs = "image",
                    description ="Upload an image, enter a comma-separated list of categories to detect, then type the one category whose boxes and mask should be shown.",
                    title = "Zero-shot Object Detection and Segmentation with Text Prompts",
                    examples=[
                        ["images/room.jpg","bed, table, plant, light, window",'plant'],
                        ["images/image2.png","banner, building,door, sign","sign"]
                        ],
                    # allow_flagging=False, 
                    # analytics_enabled=False,
                )
iface.launch()