import random

import spaces
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Grounding DINO processor and checkpoint once at startup
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

def generate_colors(labels):
    # Deterministic label -> hex color mapping, so each label keeps a stable color across runs
    random.seed(42)
    return {label: "#{:06x}".format(random.randint(0, 0xFFFFFF)) for label in labels}
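# Illustrative usage (the exact hex values depend on the seeded RNG):
#   generate_colors(["zebra", "warthog"]) -> {"zebra": "#....", "warthog": "#...."}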

# `spaces.GPU` allocates a GPU for the duration of the call on ZeroGPU Spaces;
# outside of Spaces the decorator is a no-op and inference runs on `device`.
@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    result_labels = []
    if model == "dino":
        # Grounding DINO expects a single prompt string with queries separated by ". "
        queries = ". ".join(text_queries) + "."

        height, width = img.shape[:2]
        target_sizes = [(height, width)]
        inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = dino_model(**inputs)
            outputs.logits = outputs.logits.cpu()
            outputs.pred_boxes = outputs.pred_boxes.cpu()
            results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
                                                                            box_threshold=score_threshold,
                                                                            target_sizes=target_sizes)

        boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

        for box, score, label in zip(boxes, scores, labels):
            box = [int(i) for i in box.tolist()]
            if score < score_threshold:
                continue

            # Only keep detections whose label matches one of the input text_queries
            if label in text_queries:
                result_labels.append((box, label))
    # Returning an empty list (rather than None) keeps query_image safe for unknown models
    return result_labels
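# infer() returns a list of ([x1, y1, x2, y2], label) pairs in pixel coordinates,
# e.g. [([12, 30, 240, 310], "zebra")] (box values here are purely illustrative).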

def query_image(img, text_queries, dino_threshold):
    # Split the comma-separated textbox input into individual query phrases
    text_queries = [query.strip() for query in text_queries.split(",")]
    dino_output = infer(img, text_queries, dino_threshold, "dino")
    colors = generate_colors(text_queries)
    # gr.AnnotatedImage expects (image, [(bounding_box, label), ...]) with boxes as
    # (x1, y1, x2, y2); returning a component instance also updates its color_map
    return gr.AnnotatedImage(value=(img, dino_output), color_map=colors)
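# Illustrative call (assumes a local zebra.jpg and the names np/Image from
# numpy/PIL, which this file does not import; gradio passes images as numpy arrays):
#   query_image(np.array(Image.open("zebra.jpg")), "zebra, lion", 0.16)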

dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), dino_threshold],
    outputs=[dino_output],
    title="OWLv2 ⚔ Grounding DINO",
    description="Evaluate state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection models. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in the model.",
    examples=[["./warthog.jpg", "zebra, warthog", 0.16], ["./zebra.jpg", "zebra, lion", 0.16]]
)
demo.launch(debug=True)
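
# To run this Space locally you would need, at minimum (versions are unpinned
# here, so treat the exact set as an assumption):
#   pip install torch transformers gradio spaces
# plus the example images (./warthog.jpg, ./zebra.jpg) referenced above.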