import torch
import cv2
import gradio as gr
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from collections import OrderedDict
# Use GPU if available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
model.eval()
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
def query_image(img, text_queries, max_results):
    # Comma-separated text queries, e.g. "red bus,blue bus"
    text_queries = text_queries.split(",")

    # Target size (height, width) used to rescale the normalized boxes back to pixel coordinates
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # post_process rescales the predicted boxes to the original image size
    # (newer transformers releases expose this as post_process_object_detection)
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    # Index every detection, then order by descending confidence
    results_dict = {}
    count = 0
    for box, score, label in zip(boxes, scores, labels):
        results_dict[count] = {"score": score.tolist(), "box": box, "label": label}
        count += 1
    sorted_results_dict = OrderedDict(
        sorted(results_dict.items(), key=lambda item: item[1]["score"], reverse=True)
    )
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Collect the top-10 confidence scores for the text output
    # (the sorted keys are detection indices, so look up the score rather than rounding the key)
    score_dist = []
    count = 0
    for key in sorted_results_dict:
        score_dist.append(round(sorted_results_dict[key]["score"], 2))
        count += 1
        if count == 10:
            break
    # Draw the top max_results detections on the image
    result_count = 0
    for key in sorted_results_dict:
        box = sorted_results_dict[key]["box"]
        label = sorted_results_dict[key]["label"]
        score = sorted_results_dict[key]["score"]
        box = [int(i) for i in box.tolist()]
        print("label:", label, "score:", score)

        img = cv2.rectangle(img, box[:2], box[2:], (255, 0, 0), 1)
        # Keep the caption inside the frame when the box touches the bottom edge
        if box[3] + 25 > img.shape[0]:
            y = box[3] - 10
        else:
            y = box[3] + 25
        rounded_score = round(score, 2)
        img = cv2.putText(
            img, f"({rounded_score}):{text_queries[label]}", (box[0], y), font, .5, (255, 0, 0), 1, cv2.LINE_AA
        )
        result_count += 1
        if result_count >= max_results:
            break
    return (img, f"Top {len(score_dist)} score confidences: {score_dist}")
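
# A minimal direct-call sketch for smoke-testing query_image outside Gradio, left
# commented out so the app's behaviour is unchanged. The query string and result
# count are illustrative; note cv2.imread returns a BGR array while Gradio passes RGB.
# test_img = cv2.imread("assets/bus_ovd.jpg")
# annotated, summary = query_image(test_img, "license plate", 1)
# print(summary)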
description = """
Use cases of this model
1) Given an image with an object, detect it.
(e.g. Where is Waldo? app)
2) Given an image with multiple instances of an object, detect them
(e.g. labeling tool assistance for bounding box annotation)
3) Find an object within an image using either text or image as input
(e.g. Image Search app - this would require pruning candidates using a threshold and using the score distribution in the output. Search using an input image could be useful when trying to find things that are hard to describe in text like a machine part)
Note: Inference time depends on input image size. Typically images with dimensions less than 500px has response time under 5 secs on CPU.
While most examples showcased illustrate model capabilities, some illustrate model's limitations - such as finding globe,bird cage,teapot etc.Also, the model appears to have text region detection and limited text recognition capabilities
"""
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(), "text", gr.Slider(1, 10, value=1)],
    outputs=["image", "text"],
    # server_name/server_port belong to launch(), not Interface() (see demo.launch below)
    title="Where is Waldo? (implemented with OWL-ViT)",
    description=description,
    examples=[
        ["assets/Hidden_object_game_scaled.png", "bicycle", 1],
        ["assets/Hidden_object_game_scaled.png", "laptop", 1],
        ["assets/Hidden_object_game_scaled.png", "abacus", 1],
        ["assets/Hidden_object_game_scaled.png", "frog", 1],
        ["assets/Hidden_object_game_scaled.png", "bird cage", 2],
        ["assets/Hidden_object_game_scaled.png", "globe", 2],
        ["assets/Hidden_object_game_scaled.png", "teapot", 3],
        ["assets/bus_ovd.jpg", "license plate", 1],
        ["assets/bus_ovd.jpg", "sign saying ARRIVA", 1],
        ["assets/bus_ovd.jpg", "sign saying ARRIVAL", 1],
        ["assets/bus_ovd.jpg", "crossing push button", 1],
        ["assets/bus_ovd.jpg", "building on mountain", 2],
        ["assets/bus_ovd.jpg", "road marking", 3],
        ["assets/bus_ovd.jpg", "mirror", 1],
        ["assets/bus_ovd.jpg", "traffic camera", 1],
        ["assets/bus_ovd.jpg", "red bus,blue bus", 2],
        ["assets/calf.png", "snout,tail", 1],
        ["assets/calf.png", "hoof", 4],
        ["assets/calf.png", "ear", 2],
        ["assets/calf.png", "tag", 1],
        ["assets/calf.png", "hay", 1],
        ["assets/calf.png", "barbed wire", 1],
        ["assets/calf.png", "grass", 1],
        ["assets/calf.png", "can", 2],
        ["assets/road_signs.png", "STOP", 1],
        ["assets/road_signs.png", "STOP sign", 1],
        ["assets/road_signs.png", "arrow", 1],
        ["assets/road_signs.png", "ROAD", 1],
        ["assets/road_signs.png", "triangle", 1],
    ],
)
demo.launch(server_name="0.0.0.0", server_port=80)