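# Gradio demo: zero-shot object detection with Grounding DINO.
# The user supplies an image and comma-separated candidate labels; detections
# scoring above the slider threshold are drawn via gr.AnnotatedImage.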
import spaces
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Grounding DINO processor and model once at startup.
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

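# On ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU only for the
# duration of each call to the decorated function.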
@spaces.GPU
def infer(img, text_queries, score_threshold):
  # Grounding DINO expects a single prompt of period-separated phrases.
  queries = ". ".join(text_queries) + "."

  # NumPy images are (height, width, channels); post-processing expects
  # target sizes as (height, width).
  height, width = img.shape[:2]
  target_sizes = [(height, width)]

  inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)

  with torch.no_grad():
    outputs = dino_model(**inputs)
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids,
        box_threshold=score_threshold,
        target_sizes=target_sizes,
    )

  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
  result_labels = []
  for box, score, label in zip(boxes, scores, labels):
    # Skip low-confidence detections and empty label matches.
    if score < score_threshold or label == "":
      continue
    box = [int(i) for i in box.tolist()]
    result_labels.append((box, label))
  return result_labels
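
# Note: because target_sizes is passed to post-processing, infer returns boxes
# in absolute pixel coordinates (x0, y0, x1, y1), not normalized [0, 1] values.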

def query_image(img, text_queries, dino_threshold):
    # Split the comma-separated label string into individual queries.
    text_queries = [query.strip() for query in text_queries.split(",")]
    dino_output = infer(img, text_queries, dino_threshold)
    return (img, dino_output)


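# gr.AnnotatedImage renders the (image, [(box, label), ...]) tuple returned by query_image.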
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), dino_threshold],
    outputs=[dino_output],
    title="Grounding DINO for Zero-Shot Object Detection",
    description="Evaluate the state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection model. Upload an image and enter the objects you want to detect, separated by commas, or try one of the examples. Adjust the threshold to filter out low-confidence predictions.",
    examples=[["./warthog.jpg", "zebra, warthog", 0.16], ["./zebra.png", "zebra, lion", 0.16]]
)
# debug=True surfaces server-side errors in the console while developing.
demo.launch(debug=True)