andrewkatumba commited on
Commit
8c692eb
1 Parent(s): 1ebfb13

Unique colors

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -8,9 +8,16 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
9
  dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
10
 
 
 
 
 
 
 
 
 
11
  @spaces.GPU
12
  def infer(img, text_queries, score_threshold, model):
13
-
14
  if model == "dino":
15
  queries = ""
16
  for query in text_queries:
@@ -36,9 +43,9 @@ def infer(img, text_queries, score_threshold, model):
36
  if score < score_threshold:
37
  continue
38
 
39
- if model == "dino":
40
- if label != "":
41
- result_labels.append((box, label))
42
  return result_labels
43
 
44
  def query_image(img, text_queries, dino_threshold):
@@ -47,8 +54,9 @@ def query_image(img, text_queries, dino_threshold):
47
  annotations = []
48
  for box, label in dino_output:
49
  annotations.append({"label": label, "coordinates": {"x": box[0], "y": box[1], "width": box[2] - box[0], "height": box[3] - box[1]}})
50
- return (img, {"boxes": annotations})
51
-
 
52
 
53
  dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
54
  dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
 
8
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
9
  dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
10
 
11
+ def generate_colors(labels):
12
+ import random
13
+ random.seed(42)
14
+ colors = {}
15
+ for label in labels:
16
+ colors[label] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
+ return colors
18
+
19
  @spaces.GPU
20
  def infer(img, text_queries, score_threshold, model):
 
21
  if model == "dino":
22
  queries = ""
23
  for query in text_queries:
 
43
  if score < score_threshold:
44
  continue
45
 
46
+ # Only include the labels that are part of the input text_queries
47
+ if model == "dino" and label in text_queries:
48
+ result_labels.append((box, label))
49
  return result_labels
50
 
51
  def query_image(img, text_queries, dino_threshold):
 
54
  annotations = []
55
  for box, label in dino_output:
56
  annotations.append({"label": label, "coordinates": {"x": box[0], "y": box[1], "width": box[2] - box[0], "height": box[3] - box[1]}})
57
+
58
+ colors = generate_colors(text_queries)
59
+ return (img, {"boxes": annotations, "colors": colors})
60
 
61
  dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
62
  dino_output = gr.AnnotatedImage(label="Grounding DINO Output")