andrewkatumba committed
Commit d1a452d (1 parent: 4550f69)

reverting to initial code

Files changed (1):
  1. app.py +36 -43
app.py CHANGED
@@ -1,71 +1,64 @@
 import spaces
-from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
 import torch
 import gradio as gr
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
-dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
-
-def generate_colors(labels):
-    import random
-    random.seed(42)
-    colors = {}
-    for label in labels:
-        colors[label] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
-    return colors
+dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
 
 @spaces.GPU
 def infer(img, text_queries, score_threshold, model):
-    if model == "dino":
-        queries = ""
-        for query in text_queries:
-            queries += f"{query}. "
+
+    if model == "dino":
+        queries=""
+        for query in text_queries:
+            queries += f"{query}. "
+
+    width, height = img.shape[:2]
 
-    height, width = img.shape[:2]
-    target_sizes = [(height, width)]
-    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
+    target_sizes=[(width, height)]
+    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
 
-    with torch.no_grad():
-        outputs = dino_model(**inputs)
-        outputs.logits = outputs.logits.cpu()
-        outputs.pred_boxes = outputs.pred_boxes.cpu()
-        results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
-                                                                        box_threshold=score_threshold,
-                                                                        target_sizes=target_sizes)
+    with torch.no_grad():
+        outputs = dino_model(**inputs)
+        outputs.logits = outputs.logits.cpu()
+        outputs.pred_boxes = outputs.pred_boxes.cpu()
+        results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
+                                                                        box_threshold=score_threshold,
+                                                                        target_sizes=target_sizes)
 
-    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
-    result_labels = []
+    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
+    result_labels = []
 
-    for box, score, label in zip(boxes, scores, labels):
-        box = [int(i) for i in box.tolist()]
-        if score < score_threshold:
-            continue
+    for box, score, label in zip(boxes, scores, labels):
+        box = [int(i) for i in box.tolist()]
+        if score < score_threshold:
+            continue
 
-        # Only include the labels that are part of the input text_queries
-        if model == "dino" and label in text_queries:
-            result_labels.append((box, label))
-    return result_labels
+        if model == "dino":
+            if label != "":
+                result_labels.append((box, label))
+    return result_labels
 
 def query_image(img, text_queries, dino_threshold):
-    text_queries = [query.strip() for query in text_queries.split(",")]
+    text_queries = text_queries
+    text_queries = text_queries.split(",")
     dino_output = infer(img, text_queries, dino_threshold, "dino")
-    annotations = []
-    for box, label in dino_output:
-        annotations.append({"label": label, "coordinates": {"x": box[0], "y": box[1], "width": box[2] - box[0], "height": box[3] - box[1]}})
-
-    colors = generate_colors(text_queries)
-    return (img, {"boxes": annotations, "colors": colors})
+
+
+    return (img, dino_output)
+
 
 dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
 dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
 demo = gr.Interface(
     query_image,
     inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), dino_threshold],
-    outputs=[dino_output],
+    outputs=[ dino_output],
     title="OWLv2 ⚔ Grounding DINO",
-    description="Evaluate state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection models. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in the model.",
+    description="Evaluate state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection models. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in the model.",
     examples=[["./warthog.jpg", "zebra, warthog", 0.16], ["./zebra.jpg", "zebra, lion", 0.16]]
 )
 demo.launch(debug=True)
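
The period-joined prompt that infer() builds (queries += f"{query}. ") matches the format the transformers Grounding DINO processor expects: all text queries in one string, each phrase terminated by ". ". A minimal standalone sketch of that API (the image file and threshold are borrowed from the Space's examples, not part of this commit):

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base")

image = Image.open("./zebra.jpg")    # one of the Space's example images
text = "zebra. lion. "               # queries joined the same way infer() builds them

inputs = processor(images=image, text=text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width) pairs; PIL's .size is (width, height), hence [::-1]
results = processor.post_process_grounded_object_detection(
    outputs=outputs,
    input_ids=inputs.input_ids,
    box_threshold=0.16,              # the threshold used in the Space's examples
    target_sizes=[image.size[::-1]],
)
for box, score, label in zip(results[0]["boxes"], results[0]["scores"], results[0]["labels"]):
    print(label, [int(v) for v in box.tolist()], float(score))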
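
On the output side, the restored query_image returns infer's (box, label) pairs directly, which is the annotation format gr.AnnotatedImage consumes: a tuple of the base image and a list of (annotation, label) pairs, where an annotation can be a bounding box given as (x1, y1, x2, y2) in pixels. A minimal sketch with a hypothetical hard-coded box:

import gradio as gr

def show_fixed_box(img):
    # Hypothetical coordinates, purely for illustration
    return (img, [((10, 10, 120, 160), "zebra")])

demo = gr.Interface(show_fixed_box, gr.Image(), gr.AnnotatedImage())
demo.launch()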