import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.measure import label, regionprops
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and segmentation model
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Collect up to 10 sample images from the local 'images/' directory
random_images = []
images_dir = 'images/'
for idx, filename in enumerate(os.listdir(images_dir)):
    image_path = os.path.join(images_dir, filename)
    if os.path.isfile(image_path) and idx < 10:
        random_images.append(image_path)


def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # Map a bounding box from the model's output resolution back to the original
    # image resolution. bbox is (min_row, min_col, max_row, max_col) from regionprops.
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]


def detect_using_clip(image, prompts=[], threshold=0.4):
    # Run CLIPSeg once per prompt and convert each predicted mask into
    # bounding boxes via connected-component analysis.
    model_detections = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    for i, prompt in enumerate(prompts):
        # Threshold the sigmoid activation to get a binary mask for this prompt
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0)
        # Extract connected regions from the binary mask and keep their bounding boxes
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
    return model_detections


def display_images(image, detections, prompt='traffic light'):
    # Draw the boxes detected for a single prompt onto a copy of the image
    image_copy = image.copy()
    if prompt not in detections:
        print("prompt not in query ..")
        return image_copy
    for bbox in detections[prompt]:
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
    return image_copy


def shot(image, labels_text):
    # Gradio callback: split the comma-separated prompts and run detection
    prompts = labels_text.split(',')
    global classes
    classes = prompts
    detections = detect_using_clip(image, prompts=prompts)
    print(detections)
    return 0


def add_text(text):
    labels = text.split(',')
    return labels


with gr.Blocks(title="Zero Shot Object Detection using Text Prompts") as demo:
    gr.Markdown(
        """