ydshieh commited on
Commit
cd04261
1 Parent(s): cb8dbda
Files changed (1) hide show
  1. app.py +125 -5
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
-
3
  import numpy as np
4
  import os
5
  import requests
@@ -9,6 +9,90 @@ from PIL import Image
9
  from transformers import AutoProcessor, AutoModelForVision2Seq
10
  import cv2
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def is_overlapping(rect1, rect2):
14
  x1, y1, x2, y2 = rect1
@@ -62,12 +146,20 @@ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None):
62
  text_offset_original = text_height - base_height
63
  text_spaces = 3
64
 
 
 
 
 
65
  for entity_name, (start, end), bboxes in entities:
66
- for (x1_norm, y1_norm, x2_norm, y2_norm) in bboxes:
 
 
 
67
  orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
 
68
  # draw bbox
69
  # random color
70
- color = tuple(np.random.randint(0, 255, size=3).tolist())
71
  new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
72
 
73
  l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
@@ -131,6 +223,12 @@ def main():
131
 
132
  def generate_predictions(image_input, text_input, do_sample, sampling_topp, sampling_temperature):
133
 
 
 
 
 
 
 
134
  if text_input == "Brief":
135
  text_input = "<grounding>An image of"
136
  elif text_input == "Detailed":
@@ -156,7 +254,29 @@ def main():
156
 
157
  annotated_image = draw_entity_boxes_on_image(image_input, entities, show=True)
158
 
159
- return annotated_image, processed_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  term_of_use = """
162
  ### Terms of use
@@ -191,7 +311,7 @@ def main():
191
  label="Generated Description",
192
  combine_adjacent=False,
193
  show_legend=True,
194
- ).style(color_map={"box": "red"})
195
 
196
  with gr.Row():
197
  with gr.Column():
 
1
  import gradio as gr
2
+ import random
3
  import numpy as np
4
  import os
5
  import requests
 
9
  from transformers import AutoProcessor, AutoModelForVision2Seq
10
  import cv2
11
 
12
+ colors = [
13
+ (255, 255, 0),
14
+ (255, 0, 255),
15
+ (0, 255, 255),
16
+
17
+ (255, 0, 0),
18
+ (0, 255, 0),
19
+ (0, 0, 255),
20
+
21
+ (255, 128, 0),
22
+ (255, 0, 128),
23
+ (0, 255, 128),
24
+
25
+ (128, 255, 0),
26
+ (128, 0, 255),
27
+ (0, 128, 255),
28
+
29
+ (255, 128, 128),
30
+ (128, 255, 128),
31
+ (128, 128, 255),
32
+
33
+ (128, 255, 255),
34
+ (255, 128, 255),
35
+ (255, 255, 128),
36
+
37
+ (255, 128, 64),
38
+ (255, 64, 128),
39
+ (64, 255, 128),
40
+
41
+ (128, 255, 64),
42
+ (128, 64, 255),
43
+ (64, 128, 255),
44
+
45
+ (255, 64, 64),
46
+ (64, 255, 64),
47
+ (64, 64, 255),
48
+
49
+ (64, 255, 255),
50
+ (255, 64, 255),
51
+ (255, 255, 64),
52
+
53
+ (128, 64, 64),
54
+ (64, 128, 64),
55
+ (64, 64, 128),
56
+
57
+ (64, 128, 128),
58
+ (128, 64, 128),
59
+ (128, 128, 64),
60
+
61
+ (128, 128, 0),
62
+ (128, 0, 128),
63
+ (0, 128, 128),
64
+
65
+ (128, 0, 0),
66
+ (0, 128, 0),
67
+ (0, 0, 128),
68
+
69
+ (64, 64, 0),
70
+ (64, 0, 64),
71
+ (0, 64, 64),
72
+
73
+ (64, 0, 0),
74
+ (0, 64, 0),
75
+ (0, 0, 64),
76
+
77
+ (255, 64, 0),
78
+ (255, 0, 64),
79
+ (0, 255, 64),
80
+
81
+ (64, 255, 0),
82
+ (64, 0, 255),
83
+ (0, 64, 255),
84
+
85
+ (128, 64, 0),
86
+ (128, 0, 64),
87
+ (0, 128, 64),
88
+
89
+ (64, 128, 0),
90
+ (128, 0, 255),
91
+ (0, 64, 128),
92
+ ]
93
+
94
+ color_map = {f"color_id_{color_id}": "red" for color_id, color in enumerate(colors)}
95
+
96
 
97
  def is_overlapping(rect1, rect2):
98
  x1, y1, x2, y2 = rect1
 
146
  text_offset_original = text_height - base_height
147
  text_spaces = 3
148
 
149
+ # num_bboxes = sum(len(x[-1]) for x in entities)
150
+ used_colors = colors # random.sample(colors, k=num_bboxes)
151
+
152
+ color_id = -1
153
  for entity_name, (start, end), bboxes in entities:
154
+ color_id += 1
155
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
156
+ if start is None and bbox_id > 0:
157
+ color_id += 1
158
  orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
159
+
160
  # draw bbox
161
  # random color
162
+ color = used_colors[bbox_id] # tuple(np.random.randint(0, 255, size=3).tolist())
163
  new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
164
 
165
  l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
 
223
 
224
  def generate_predictions(image_input, text_input, do_sample, sampling_topp, sampling_temperature):
225
 
226
+ user_image_path = "/tmp/user_input_test_image.jpg"
227
+ # This will be of `.jpg` format
228
+ image_input.save(user_image_path)
229
+ # This might give different results from the original argument `image_input`
230
+ image_input = Image.open(user_image_path)
231
+
232
  if text_input == "Brief":
233
  text_input = "<grounding>An image of"
234
  elif text_input == "Detailed":
 
254
 
255
  annotated_image = draw_entity_boxes_on_image(image_input, entities, show=True)
256
 
257
+ color_id = -1
258
+ entity_info = []
259
+ for entity_name, (start, end), bboxes in entities:
260
+ color_id += 1
261
+ for bbox_id, _ in enumerate(bboxes):
262
+ if start is None and bbox_id > 0:
263
+ color_id += 1
264
+ if start is not None:
265
+ entity_info.append(((start, end), color_id))
266
+
267
+ colored_text = []
268
+ prev_start = 0
269
+ end = 0
270
+ for idx, ((start, end), color_id) in enumerate(entity_info):
271
+ if start > prev_start:
272
+ colored_text.append((processed_text[prev_start:start], None))
273
+ colored_text.append((processed_text[start:end], f"color_id_{color_id}"))
274
+ prev_start = start
275
+
276
+ if end < len(processed_text):
277
+ colored_text.append((processed_text[end:len(processed_text)], None))
278
+
279
+ return annotated_image, colored_text
280
 
281
  term_of_use = """
282
  ### Terms of use
 
311
  label="Generated Description",
312
  combine_adjacent=False,
313
  show_legend=True,
314
+ ).style(color_map=color_map)
315
 
316
  with gr.Row():
317
  with gr.Column():