File size: 2,507 Bytes
1c2d63b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ee2100
1c2d63b
2ee2100
1c2d63b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import cv2
import gradio as gr
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection


# Run on CUDA when torch reports a GPU; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and its processor are loaded from the same pretrained checkpoint.
_CHECKPOINT = "google/owlvit-base-patch32"

model = OwlViTForObjectDetection.from_pretrained(_CHECKPOINT)
model = model.to(device)
model.eval()  # inference only
processor = OwlViTProcessor.from_pretrained(_CHECKPOINT)


def query_image(img, text_queries, score_threshold):
    """Draw OWL-ViT detections for comma-separated text queries onto an image.

    Args:
        img: Image as a NumPy array. ``img.shape[:2]`` is handed to the
            processor as the target size, and boxes/labels are drawn onto
            this array.
        text_queries: Comma-separated object descriptions, e.g.
            ``"coffee mug, spoon, plate"``. Whitespace around each entry is
            stripped and empty entries are ignored.
        score_threshold: Minimum score a detection needs in order to be drawn.

    Returns:
        The image with a rectangle and the matching query text drawn for
        every detection whose score is at least ``score_threshold``. When no
        non-empty query is supplied the image is returned untouched.
    """
    # Strip padding so "a, b" is labelled "b" rather than " b", and drop
    # blanks (e.g. from a trailing comma) so they are never sent to the model.
    queries = [query.strip() for query in text_queries.split(",")]
    queries = [query for query in queries if query]
    if not queries:
        return img

    img_height = img.shape[0]
    target_sizes = torch.Tensor([img.shape[:2]])
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Move predictions to the CPU before post-processing; target_sizes was
    # created on the CPU above.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    font = cv2.FONT_HERSHEY_SIMPLEX
    color = (255, 0, 0)

    for box, score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
        if score < score_threshold:
            continue

        x_min, y_min, x_max, y_max = (int(coord) for coord in box)
        img = cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color, 5)

        # Put the label just below the box, unless that would run past the
        # bottom edge of this image; in that case put it just inside the box.
        if y_max + 25 > img_height:
            text_y = y_max - 10
        else:
            text_y = y_max + 25

        img = cv2.putText(
            img, queries[int(label)], (x_min, text_y), font, 1, color, 2, cv2.LINE_AA
        )
    return img


# Description text passed to the Interface below; it embeds an HTML link.
description = """
You can use OWL-ViT to query images with text descriptions of any object. 
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
can also use the score threshold slider to set a threshold to filter out low probability predictions.  <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
"""

# One row per example, in the same order as the inputs:
# [image path, comma-separated queries, score threshold].
_EXAMPLES = [
    ["assets/astronaut.png", "human face, rocket, star-spangled banner, nasa badge", 0.11],
    ["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
    ["assets/butterflies.jpeg", "orange butterfly", 0.3],
]

# Inputs map positionally onto query_image(img, text_queries, score_threshold).
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(), "text", gr.Slider(minimum=0, maximum=1, value=0.1)],
    outputs="image",
    title="Zero-Shot Object Detection with OWL-ViT",
    description=description,
    examples=_EXAMPLES,
)
demo.launch()