import torch
import cv2
import gradio as gr
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")


def query_image(img, text_queries, max_results):
    text_queries = text_queries.split(",")

    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Move predictions to CPU before post-processing
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    # Sort detections by confidence, highest first
    detections = sorted(
        (
            {"score": score.tolist(), "box": box, "label": label}
            for box, score, label in zip(boxes, scores, labels)
        ),
        key=lambda det: det["score"],
        reverse=True,
    )

    # Distribution of the top-10 confidences, reported alongside the image
    score_dist = [round(det["score"], 2) for det in detections[:10]]

    # Draw the top max_results detections (slider values may arrive as floats)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det in detections[: int(max_results)]:
        box = [int(i) for i in det["box"].tolist()]
        score = round(det["score"], 2)
        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 1)
        # Caption below the box, or above it when too close to the bottom edge
        y = box[3] - 10 if box[3] + 25 > img.shape[0] else box[3] + 25
        img = cv2.putText(
            img,
            f"({score}):{text_queries[det['label']]}",
            (box[0], y),
            font,
            0.5,
            (255, 0, 0),
            1,
            cv2.LINE_AA,
        )

    return (img, f"Top {len(score_dist)} score confidences: {score_dist}")


description = """
This app is a tweaked variation of Alara Dirik's OWL-ViT demo.
Use cases of this model:
1) Given an image with an object, detect it (e.g. a Where is Waldo? app)
2) Given an image with multiple instances of an object, detect them all (e.g. assisting a labeling tool with bounding box annotation)
3) Find an object within an image using either text or an image as input (e.g. an Image Search app; this would require pruning candidates with a score threshold and using the score distribution in the output, as sketched after the app code). Searching with an input image can be useful for finding things that are hard to describe in text, such as a machine part.

Links to apps/notebooks of other SOTA models for open-vocabulary or zero-shot object detection:
a) RegionCLIP
b) Colab notebook for Object-Centric-OVD

Note: Inference time depends on the input image size. Typically, images with dimensions under 500px get a response in under 5 seconds on CPU (a downscaling sketch appears after the app code).
While most of the examples showcase the model's capabilities, some illustrate its limitations, such as finding a globe, bird cage, or teapot. The model also appears to have text region detection and limited text recognition capabilities.
Images below are from the Wikipedia, COCO, and PASCAL VOC 2012 datasets.
""" demo = gr.Interface( query_image, inputs=[gr.Image(), "text",gr.Slider(1, 10, value=1)], outputs=["image","text"], server_port=80, server_name="0.0.0.0", title="Where is Waldo? (implemented with OWL-ViT)", description=description, examples=[ ["assets/Hidden_object_game_scaled.png", "bicycle", 1], ["assets/Hidden_object_game_scaled.png", "laptop", 1], ["assets/Hidden_object_game_scaled.png", "abacus", 1], ["assets/Hidden_object_game_scaled.png", "frog", 1], ["assets/Hidden_object_game_scaled.png", "bird cage", 2], ["assets/Hidden_object_game_scaled.png", "globe", 2], ["assets/Hidden_object_game_scaled.png", "teapot", 3], ["assets/bus_ovd.jpg", "license plate", 1], ["assets/bus_ovd.jpg", "sign saying ARRIVA", 1], ["assets/bus_ovd.jpg", "sign saying ARRIVAL", 1], ["assets/bus_ovd.jpg", "crossing push button", 1], ["assets/bus_ovd.jpg", "building on moutain", 2], ["assets/bus_ovd.jpg", "road marking", 3], ["assets/bus_ovd.jpg", "mirror", 1], ["assets/bus_ovd.jpg", "traffic camera", 1], ["assets/bus_ovd.jpg", "red bus,blue bus", 2], ["assets/calf.png", "snout,tail", 1], ["assets/calf.png", "hoof", 4], ["assets/calf.png", "ear", 2], ["assets/calf.png", "tag", 1], ["assets/calf.png", "hay", 1], ["assets/calf.png", "barbed wire", 1], ["assets/calf.png", "grass", 1], ["assets/calf.png", "can", 2], ["assets/road_signs.png", "STOP", 1], ["assets/road_signs.png", "STOP sign", 1], ["assets/road_signs.png", "arrow", 1], ["assets/road_signs.png", "ROAD", 1], ["assets/road_signs.png", "triangle", 1], ], ) demo.launch()