banao-tech committed (verified)
Commit 24bf6bb · 1 Parent(s): d72c9a4

Update utils.py

Files changed (1):
  1. utils.py +335 -229
utils.py CHANGED
@@ -1,49 +1,58 @@
1
- # from ultralytics import YOLO
2
  import os
3
  import io
4
  import base64
5
  import time
6
- from PIL import Image, ImageDraw, ImageFont
7
- import json
8
- import requests
9
- # utility function
10
- import os
11
- from openai import AzureOpenAI
12
-
13
  import json
14
  import sys
15
- import os
16
- import cv2
 
 
17
  import numpy as np
18
- # %matplotlib inline
 
19
  from matplotlib import pyplot as plt
 
20
  import easyocr
21
  from paddleocr import PaddleOCR
22
  reader = easyocr.Reader(['en'])
23
  paddle_ocr = PaddleOCR(
24
- lang='en', # other lang also available
25
  use_angle_cls=False,
26
- use_gpu=False, # using cuda will conflict with pytorch in the same process
27
  show_log=False,
28
  max_batch_size=1024,
29
  use_dilation=True, # improves accuracy
30
  det_db_score_mode='slow', # improves accuracy
31
- rec_batch_num=1024)
32
- import time
33
- import base64
34
-
35
- import os
36
- import ast
37
- import torch
38
- from typing import Tuple, List
39
- from torchvision.ops import box_convert
40
- import re
41
- from torchvision.transforms import ToPILImage
42
- import supervision as sv
43
- import torchvision.transforms as T
44
 
45
 
46
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
47
  if not device:
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
  if model_name == "blip2":
@@ -51,44 +60,62 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2
51
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
52
  if device == 'cpu':
53
  model = Blip2ForConditionalGeneration.from_pretrained(
54
- model_name_or_path, device_map=None, torch_dtype=torch.float32
55
- )
56
  else:
57
  model = Blip2ForConditionalGeneration.from_pretrained(
58
- model_name_or_path, device_map=None, torch_dtype=torch.float16
59
- ).to(device)
60
  elif model_name == "florence2":
61
- from transformers import AutoProcessor, AutoModelForCausalLM
62
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
63
  if device == 'cpu':
64
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
 
 
65
  else:
66
- model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
 
 
67
  return {'model': model.to(device), 'processor': processor}
68
 
69
 
70
  def get_yolo_model(model_path):
 
 
 
71
  from ultralytics import YOLO
72
- # Load the model.
73
  model = YOLO(model_path)
74
  return model
75
 
76
 
77
  @torch.inference_mode()
78
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
79
- # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
80
-
81
  to_pil = ToPILImage()
82
  if starting_idx:
83
  non_ocr_boxes = filtered_boxes[starting_idx:]
84
  else:
85
  non_ocr_boxes = filtered_boxes
86
- croped_pil_image = []
87
- for i, coord in enumerate(non_ocr_boxes):
88
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
89
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
90
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
91
- croped_pil_image.append(to_pil(cropped_image))
92
 
93
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
94
  if not prompt:
@@ -99,17 +126,29 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
99
 
100
  generated_texts = []
101
  device = model.device
102
- for i in range(0, len(croped_pil_image), batch_size):
103
- start = time.time()
104
- batch = croped_pil_image[i:i+batch_size]
105
  if model.device.type == 'cuda':
106
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
107
  else:
108
- inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
109
  if 'florence' in model.config.name_or_path:
110
- generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=100,num_beams=3, do_sample=False)
111
  else:
112
- generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
113
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
114
  generated_text = [gen.strip() for gen in generated_text]
115
  generated_texts.extend(generated_text)
@@ -117,52 +156,56 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
117
  return generated_texts
118
 
119
 
120
-
121
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
 
 
 
122
  to_pil = ToPILImage()
123
  if ocr_bbox:
124
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
125
  else:
126
  non_ocr_boxes = filtered_boxes
127
- croped_pil_image = []
128
- for i, coord in enumerate(non_ocr_boxes):
129
- xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
130
- ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
131
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
132
- croped_pil_image.append(to_pil(cropped_image))
133
 
134
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
135
  device = model.device
136
- messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
137
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
138
 
139
  batch_size = 5 # Number of samples per batch
140
  generated_texts = []
141
 
142
- for i in range(0, len(croped_pil_image), batch_size):
143
- images = croped_pil_image[i:i+batch_size]
144
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
145
- inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
146
  texts = [prompt] * len(images)
147
- for i, txt in enumerate(texts):
148
- input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
149
- inputs['input_ids'].append(input['input_ids'])
150
- inputs['attention_mask'].append(input['attention_mask'])
151
- inputs['pixel_values'].append(input['pixel_values'])
152
- inputs['image_sizes'].append(input['image_sizes'])
153
- max_len = max([x.shape[1] for x in inputs['input_ids']])
154
- for i, v in enumerate(inputs['input_ids']):
155
- inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
156
- inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
 
 
157
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
158
 
159
- generation_args = {
160
- "max_new_tokens": 25,
161
- "temperature": 0.01,
162
- "do_sample": False,
163
- }
164
- generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
165
- # # remove input tokens
166
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
167
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
168
  response = [res.strip('\n').strip() for res in response]
@@ -170,7 +213,19 @@ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, captio
170
 
171
  return generated_texts
172
 
 
173
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
174
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
175
 
176
  def box_area(box):
@@ -184,39 +239,30 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
184
  return max(0, x2 - x1) * max(0, y2 - y1)
185
 
186
  def IoU(box1, box2):
187
- intersection = intersection_area(box1, box2)
188
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
189
- if box_area(box1) > 0 and box_area(box2) > 0:
190
- ratio1 = intersection / box_area(box1)
191
- ratio2 = intersection / box_area(box2)
192
- else:
193
- ratio1, ratio2 = 0, 0
194
- return max(intersection / union, ratio1, ratio2)
195
 
196
  def is_inside(box1, box2):
197
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
198
- intersection = intersection_area(box1, box2)
199
- ratio1 = intersection / box_area(box1)
200
- return ratio1 > 0.95
201
 
202
  boxes = boxes.tolist()
203
  filtered_boxes = []
204
  if ocr_bbox:
205
  filtered_boxes.extend(ocr_bbox)
206
- # print('ocr_bbox!!!', ocr_bbox)
207
  for i, box1 in enumerate(boxes):
208
- # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
209
  is_valid_box = True
210
  for j, box2 in enumerate(boxes):
211
- # keep the smaller box
212
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
213
  is_valid_box = False
214
  break
215
  if is_valid_box:
216
- # add the following 2 lines to include ocr bbox
217
  if ocr_bbox:
218
- # only add the box if it does not overlap with any ocr bbox
219
- if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
220
  filtered_boxes.append(box1)
221
  else:
222
  filtered_boxes.append(box1)
@@ -224,11 +270,17 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
224
 
225
 
226
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
227
- '''
228
- ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
229
- boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
230
-
231
- '''
232
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
233
 
234
  def box_area(box):
@@ -242,99 +294,91 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
242
  return max(0, x2 - x1) * max(0, y2 - y1)
243
 
244
  def IoU(box1, box2):
245
- intersection = intersection_area(box1, box2)
246
- union = box_area(box1) + box_area(box2) - intersection + 1e-6
247
- if box_area(box1) > 0 and box_area(box2) > 0:
248
- ratio1 = intersection / box_area(box1)
249
- ratio2 = intersection / box_area(box2)
250
- else:
251
- ratio1, ratio2 = 0, 0
252
- return max(intersection / union, ratio1, ratio2)
253
 
254
  def is_inside(box1, box2):
255
- # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
256
- intersection = intersection_area(box1, box2)
257
- ratio1 = intersection / box_area(box1)
258
- return ratio1 > 0.80
259
 
260
- # boxes = boxes.tolist()
261
  filtered_boxes = []
262
  if ocr_bbox:
263
  filtered_boxes.extend(ocr_bbox)
264
- # print('ocr_bbox!!!', ocr_bbox)
265
  for i, box1_elem in enumerate(boxes):
266
  box1 = box1_elem['bbox']
267
  is_valid_box = True
268
  for j, box2_elem in enumerate(boxes):
269
- # keep the smaller box
270
  box2 = box2_elem['bbox']
271
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
272
  is_valid_box = False
273
  break
274
  if is_valid_box:
275
- # add the following 2 lines to include ocr bbox
276
  if ocr_bbox:
277
- # keep yolo boxes + prioritize ocr label
278
  box_added = False
279
  for box3_elem in ocr_bbox:
280
- if not box_added:
281
- box3 = box3_elem['bbox']
282
- if is_inside(box3, box1): # ocr inside icon
283
- # box_added = True
284
- # delete the box3_elem from ocr_bbox
285
- try:
286
- filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
287
- filtered_boxes.remove(box3_elem)
288
- # print('remove ocr bbox:', box3_elem)
289
- except:
290
- continue
291
- # break
292
- elif is_inside(box1, box3): # icon inside ocr
293
- box_added = True
294
- # try:
295
- # filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
296
- # filtered_boxes.remove(box3_elem)
297
- # except:
298
- # continue
299
- break
300
- else:
301
  continue
 
 
 
302
  if not box_added:
303
- filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
304
-
305
  else:
306
  filtered_boxes.append(box1)
307
- return filtered_boxes # torch.tensor(filtered_boxes)
308
 
309
 
310
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
311
- transform = T.Compose(
312
- [
313
- T.RandomResize([800], max_size=1333),
314
- T.ToTensor(),
315
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
316
- ]
317
- )
318
  image_source = Image.open(image_path).convert("RGB")
319
  image = np.asarray(image_source)
320
  image_transformed, _ = transform(image_source, None)
321
  return image, image_transformed
322
 
323
 
324
- def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
325
- text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
326
- """
327
- This function annotates an image with bounding boxes and labels.
328
-
329
- Parameters:
330
- image_source (np.ndarray): The source image to be annotated.
331
- boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
332
- logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
333
- phrases (List[str]): A list of labels for each bounding box.
334
- text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
335
-
336
  Returns:
337
- np.ndarray: The annotated image.
338
  """
339
  h, w, _ = image_source.shape
340
  boxes = boxes * torch.Tensor([w, h, w, h])
@@ -344,30 +388,43 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
344
 
345
  labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
346
 
347
- from util.box_annotator import BoxAnnotator
348
- box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
 
 
349
  annotated_frame = image_source.copy()
350
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
351
 
352
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
353
  return annotated_frame, label_coordinates
354
 
355
 
356
  def predict(model, image, caption, box_threshold, text_threshold):
357
- """ Use huggingface model to replace the original model
358
  """
359
- model, processor = model['model'], model['processor']
360
- device = model.device
361
 
362
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
363
  with torch.no_grad():
364
- outputs = model(**inputs)
365
 
366
  results = processor.post_process_grounded_object_detection(
367
  outputs,
368
  inputs.input_ids,
369
- box_threshold=box_threshold, # 0.4,
370
- text_threshold=text_threshold, # 0.3,
371
  target_sizes=[image.size[::-1]]
372
  )[0]
373
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
@@ -375,71 +432,106 @@ def predict(model, image, caption, box_threshold, text_threshold):
375
 
376
 
377
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
378
- """Use YOLO model for object detection with correct parameters"""
379
  kwargs = {
380
- 'conf': box_threshold, # Correct confidence parameter
381
- 'iou': iou_threshold, # Correct IoU parameter
382
  'verbose': False
383
  }
384
-
385
  if scale_img:
386
  kwargs['imgsz'] = imgsz
387
-
388
  results = model.predict(image_path, **kwargs)
389
  boxes = results[0].boxes.xyxy
390
  conf = results[0].boxes.conf
391
  return boxes, conf, [str(i) for i in range(len(boxes))]
392
 
393
 
394
- def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=None):
395
- """ ocr_bbox: list of xyxy format bbox
396
  """
397
  image_source = Image.open(img_path).convert("RGB")
398
  w, h = image_source.size
399
  if not imgsz:
400
  imgsz = (h, w)
401
- # print('image size:', w, h)
402
- xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
 
 
 
403
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
404
- image_source = np.asarray(image_source)
405
  phrases = [str(i) for i in range(len(phrases))]
406
 
407
- # annotate the image with labels
408
- h, w, _ = image_source.shape
409
  if ocr_bbox:
410
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
411
- ocr_bbox=ocr_bbox.tolist()
412
  else:
413
  print('no ocr bbox!!!')
414
  ocr_bbox = None
415
- # filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
416
- # starting_idx = len(ocr_bbox)
417
- # print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
418
 
419
- ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
420
- xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
 
 
421
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
422
 
423
- # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
424
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
425
- # get the index of the first 'content': None
426
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
427
- filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
428
 
429
-
430
- # get parsed icon local semantics
431
  if use_local_semantics:
432
  caption_model = caption_model_processor['model']
433
- if 'phi3_v' in caption_model.config.model_type:
434
- parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
435
  else:
436
- parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
437
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
438
  icon_start = len(ocr_text)
439
  parsed_content_icon_ls = []
440
- # fill the filtered_boxes_elem None content with parsed_content_icon in order
441
- for i, box in enumerate(filtered_boxes_elem):
442
- if box['content'] is None:
443
  box['content'] = parsed_content_icon.pop(0)
444
  for i, txt in enumerate(parsed_content_icon):
445
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
@@ -448,51 +540,72 @@ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_
448
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
449
  parsed_content_merged = ocr_text
450
 
451
- filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
452
-
453
- phrases = [i for i in range(len(filtered_boxes))]
454
 
455
- # draw boxes
456
  if draw_bbox_config:
457
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
 
 
458
  else:
459
- annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
 
 
 
460
 
461
  pil_img = Image.fromarray(annotated_frame)
462
  buffered = io.BytesIO()
463
  pil_img.save(buffered, format="PNG")
464
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
 
465
  if output_coord_in_ratio:
466
- # h, w, _ = image_source.shape
467
- label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
468
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
469
 
470
  return encoded_image, label_coordinates, filtered_boxes_elem
471
 
472
 
473
  def get_xywh(input):
474
- x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
475
- x, y, w, h = int(x), int(y), int(w), int(h)
476
- return x, y, w, h
477
 
478
  def get_xyxy(input):
479
- x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
480
- x, y, xp, yp = int(x), int(y), int(xp), int(yp)
481
- return x, y, xp, yp
482
 
483
  def get_xywh_yolo(input):
484
- x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
485
- x, y, w, h = int(x), int(y), int(w), int(h)
486
- return x, y, w, h
487
-
 
 
 
488
 
489
 
490
- def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
491
  if use_paddleocr:
492
- if easyocr_args is None:
493
- text_threshold = 0.5
494
- else:
495
- text_threshold = easyocr_args['text_threshold']
496
  result = paddle_ocr.ocr(image_path, cls=False)[0]
497
  conf = [item[1] for item in result]
498
  coord = [item[0] for item in result if item[1][1] > text_threshold]
@@ -501,28 +614,21 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
501
  if easyocr_args is None:
502
  easyocr_args = {}
503
  result = reader.readtext(image_path, **easyocr_args)
504
- # print('goal filtering pred:', result[-5:])
505
  coord = [item[0] for item in result]
506
  text = [item[1] for item in result]
507
- # read the image using cv2
508
  if display_img:
509
  opencv_img = cv2.imread(image_path)
510
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
511
  bb = []
512
  for item in coord:
513
  x, y, a, b = get_xywh(item)
514
- # print(x, y, a, b)
515
  bb.append((x, y, a, b))
516
- cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
517
-
518
- # Display the image
519
  plt.imshow(opencv_img)
520
  else:
521
  if output_bb_format == 'xywh':
522
  bb = [get_xywh(item) for item in coord]
523
  elif output_bb_format == 'xyxy':
524
  bb = [get_xyxy(item) for item in coord]
525
- # print('bounding box!!!', bb)
526
  return (text, bb), goal_filtering
527
-
528
-
 
1
+ """
2
+ utils.py
3
+
4
+ This module contains utility functions for:
5
+ - Loading and processing images
6
+ - Object detection with YOLO
7
+ - OCR with EasyOCR / PaddleOCR
8
+ - Image annotation and bounding box manipulation
9
+ - Captioning / semantic parsing of detected icons
10
+ """
11
+
12
  import os
13
  import io
14
  import base64
15
  import time
16
  import json
17
  import sys
18
+ import re
19
+ from typing import Tuple, List
20
+
21
+ import torch
22
  import numpy as np
23
+ import cv2
24
+ from PIL import Image, ImageDraw, ImageFont
25
  from matplotlib import pyplot as plt
26
+
27
  import easyocr
28
  from paddleocr import PaddleOCR
29
+ import supervision as sv
30
+ import torchvision.transforms as T
31
+ from torchvision.transforms import ToPILImage
32
+ from torchvision.ops import box_convert
33
+
34
+ # Optional: import AzureOpenAI if used
35
+ from openai import AzureOpenAI
36
+
37
+ # Initialize OCR readers
38
  reader = easyocr.Reader(['en'])
39
  paddle_ocr = PaddleOCR(
40
+ lang='en', # other languages available
41
  use_angle_cls=False,
42
+ use_gpu=False, # using cuda might conflict with PyTorch in the same process
43
  show_log=False,
44
  max_batch_size=1024,
45
  use_dilation=True, # improves accuracy
46
  det_db_score_mode='slow', # improves accuracy
47
+ rec_batch_num=1024
48
+ )
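# Usage sketch (illustrative; "screenshot.png" is a placeholder path): the module-level
# readers can also be called directly, e.g.
#   reader.readtext("screenshot.png")              -> [(corner_points, text, confidence), ...]
#   paddle_ocr.ocr("screenshot.png", cls=False)[0] -> [(corner_points, (text, confidence)), ...]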
49
 
50
 
51
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
52
+ """
53
+ Loads the captioning model and processor.
54
+ Supports either BLIP2 or Florence-2 models.
55
+ """
56
  if not device:
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
58
  if model_name == "blip2":
 
60
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
61
  if device == 'cpu':
62
  model = Blip2ForConditionalGeneration.from_pretrained(
63
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
64
+ )
65
  else:
66
  model = Blip2ForConditionalGeneration.from_pretrained(
67
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
68
+ ).to(device)
69
  elif model_name == "florence2":
70
+ from transformers import AutoProcessor, AutoModelForCausalLM
71
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
72
  if device == 'cpu':
73
+ model = AutoModelForCausalLM.from_pretrained(
74
+ model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True
75
+ )
76
  else:
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True
79
+ ).to(device)
80
  return {'model': model.to(device), 'processor': processor}
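# Example usage (sketch): load the caption model once and reuse the returned dict.
# The Florence-2 checkpoint name below is an assumption, not fixed by this module.
caption_model_processor = get_caption_model_processor(
    "florence2", model_name_or_path="microsoft/Florence-2-base", device="cpu")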
81
 
82
 
83
  def get_yolo_model(model_path):
84
+ """
85
+ Loads a YOLO model from a given model_path using ultralytics.
86
+ """
87
  from ultralytics import YOLO
 
88
  model = YOLO(model_path)
89
  return model
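# Example usage (sketch; the weights path is a placeholder):
som_model = get_yolo_model("weights/icon_detect/best.pt")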
90
 
91
 
92
  @torch.inference_mode()
93
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
94
+ """
95
+ Generates parsed textual content for detected icons from the image.
96
+
97
+ Args:
98
+ filtered_boxes: Tensor of bounding boxes.
99
+ starting_idx: Starting index for non-OCR boxes.
100
+ image_source: Original image as a NumPy array.
101
+ caption_model_processor: Dictionary with keys 'model' and 'processor'.
102
+ prompt: Optional prompt text.
103
+ batch_size: Batch size for processing.
104
+
105
+ Returns:
106
+ List of generated texts.
107
+ """
108
  to_pil = ToPILImage()
109
  if starting_idx:
110
  non_ocr_boxes = filtered_boxes[starting_idx:]
111
  else:
112
  non_ocr_boxes = filtered_boxes
113
+ cropped_pil_images = []
114
+ for coord in non_ocr_boxes:
115
+ xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
116
+ ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
117
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
118
+ cropped_pil_images.append(to_pil(cropped_image))
119
 
120
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
121
  if not prompt:
 
126
 
127
  generated_texts = []
128
  device = model.device
129
+ for i in range(0, len(cropped_pil_images), batch_size):
130
+ batch = cropped_pil_images[i:i+batch_size]
 
131
  if model.device.type == 'cuda':
132
+ inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
133
  else:
134
+ inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device)
135
  if 'florence' in model.config.name_or_path:
136
+ generated_ids = model.generate(
137
+ input_ids=inputs["input_ids"],
138
+ pixel_values=inputs["pixel_values"],
139
+ max_new_tokens=100,
140
+ num_beams=3,
141
+ do_sample=False
142
+ )
143
  else:
144
+ generated_ids = model.generate(
145
+ **inputs,
146
+ max_length=100,
147
+ num_beams=5,
148
+ no_repeat_ngram_size=2,
149
+ early_stopping=True,
150
+ num_return_sequences=1
151
+ )
152
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
153
  generated_text = [gen.strip() for gen in generated_text]
154
  generated_texts.extend(generated_text)
 
156
  return generated_texts
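# Example usage (sketch): caption two icon crops directly. Boxes are normalized xyxy, the
# image is an RGB numpy array, and caption_model_processor is assumed to come from
# get_caption_model_processor above; all concrete values here are illustrative.
example_image = np.asarray(Image.open("screenshot.png").convert("RGB"))
example_boxes = torch.tensor([[0.10, 0.10, 0.20, 0.18],
                              [0.55, 0.40, 0.62, 0.47]])
icon_captions = get_parsed_content_icon(
    example_boxes, starting_idx=0, image_source=example_image,
    caption_model_processor=caption_model_processor, batch_size=8)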
157
 
158
 
 
159
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
160
+ """
161
+ Generates parsed textual content for detected icons using the phi3_v model variant.
162
+ """
163
  to_pil = ToPILImage()
164
  if ocr_bbox:
165
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
166
  else:
167
  non_ocr_boxes = filtered_boxes
168
+ cropped_pil_images = []
169
+ for coord in non_ocr_boxes:
170
+ xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
171
+ ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
172
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
173
+ cropped_pil_images.append(to_pil(cropped_image))
174
 
175
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
176
  device = model.device
177
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
178
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
179
 
180
  batch_size = 5 # Number of samples per batch
181
  generated_texts = []
182
 
183
+ for i in range(0, len(cropped_pil_images), batch_size):
184
+ images = cropped_pil_images[i:i+batch_size]
185
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
186
+ inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
187
  texts = [prompt] * len(images)
188
+ for idx, txt in enumerate(texts):
189
+ inp = processor._convert_images_texts_to_inputs(image_inputs[idx], txt, return_tensors="pt")
190
+ inputs['input_ids'].append(inp['input_ids'])
191
+ inputs['attention_mask'].append(inp['attention_mask'])
192
+ inputs['pixel_values'].append(inp['pixel_values'])
193
+ inputs['image_sizes'].append(inp['image_sizes'])
194
+ max_len = max(x.shape[1] for x in inputs['input_ids'])
195
+ for idx, v in enumerate(inputs['input_ids']):
196
+ pad_tensor = processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long)
197
+ inputs['input_ids'][idx] = torch.cat([pad_tensor, v], dim=1)
198
+ pad_att = torch.zeros(1, max_len - v.shape[1], dtype=torch.long)
199
+ inputs['attention_mask'][idx] = torch.cat([pad_att, inputs['attention_mask'][idx]], dim=1)
200
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
201
 
202
+ generation_args = {
203
+ "max_new_tokens": 25,
204
+ "temperature": 0.01,
205
+ "do_sample": False,
206
+ }
207
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
208
+ # Remove input tokens from the generated sequence
209
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
210
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
211
  response = [res.strip('\n').strip() for res in response]
 
213
 
214
  return generated_texts
215
 
216
+
217
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
218
+ """
219
+ Removes overlapping bounding boxes based on IoU and optionally considers OCR boxes.
220
+
221
+ Args:
222
+ boxes: Tensor of bounding boxes (in xyxy format).
223
+ iou_threshold: IoU threshold to determine overlaps.
224
+ ocr_bbox: Optional list of OCR bounding boxes.
225
+
226
+ Returns:
227
+ Filtered boxes as a torch.Tensor.
228
+ """
229
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
230
 
231
  def box_area(box):
 
239
  return max(0, x2 - x1) * max(0, y2 - y1)
240
 
241
  def IoU(box1, box2):
242
+ inter = intersection_area(box1, box2)
243
+ union = box_area(box1) + box_area(box2) - inter + 1e-6
244
+ ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
245
+ ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
246
+ return max(inter / union, ratio1, ratio2)
 
 
 
247
 
248
  def is_inside(box1, box2):
249
+ inter = intersection_area(box1, box2)
250
+ return (inter / box_area(box1)) > 0.95
 
 
251
 
252
  boxes = boxes.tolist()
253
  filtered_boxes = []
254
  if ocr_bbox:
255
  filtered_boxes.extend(ocr_bbox)
 
256
  for i, box1 in enumerate(boxes):
 
257
  is_valid_box = True
258
  for j, box2 in enumerate(boxes):
 
259
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
260
  is_valid_box = False
261
  break
262
  if is_valid_box:
 
263
  if ocr_bbox:
264
+ # Only add the box if it does not overlap with any OCR box
265
+ if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for box3 in ocr_bbox):
266
  filtered_boxes.append(box1)
267
  else:
268
  filtered_boxes.append(box1)
 
270
 
271
 
272
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
273
+ """
274
+ Removes overlapping boxes with OCR priority.
275
+
276
+ Args:
277
+ boxes: List of dictionaries, each with keys: 'type', 'bbox', 'interactivity', 'content'.
278
+ iou_threshold: IoU threshold for removal.
279
+ ocr_bbox: List of OCR box dictionaries.
280
+
281
+ Returns:
282
+ A list of filtered box dictionaries.
283
+ """
284
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
285
 
286
  def box_area(box):
 
294
  return max(0, x2 - x1) * max(0, y2 - y1)
295
 
296
  def IoU(box1, box2):
297
+ inter = intersection_area(box1, box2)
298
+ union = box_area(box1) + box_area(box2) - inter + 1e-6
299
+ ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
300
+ ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
301
+ return max(inter / union, ratio1, ratio2)
 
 
 
302
 
303
  def is_inside(box1, box2):
304
+ inter = intersection_area(box1, box2)
305
+ return (inter / box_area(box1)) > 0.80
 
 
306
 
 
307
  filtered_boxes = []
308
  if ocr_bbox:
309
  filtered_boxes.extend(ocr_bbox)
 
310
  for i, box1_elem in enumerate(boxes):
311
  box1 = box1_elem['bbox']
312
  is_valid_box = True
313
  for j, box2_elem in enumerate(boxes):
 
314
  box2 = box2_elem['bbox']
315
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
316
  is_valid_box = False
317
  break
318
  if is_valid_box:
 
319
  if ocr_bbox:
 
320
  box_added = False
321
  for box3_elem in ocr_bbox:
322
+ box3 = box3_elem['bbox']
323
+ if is_inside(box3, box1):
324
+ try:
325
+ filtered_boxes.append({
326
+ 'type': 'text',
327
+ 'bbox': box1_elem['bbox'],
328
+ 'interactivity': True,
329
+ 'content': box3_elem['content']
330
+ })
331
+ filtered_boxes.remove(box3_elem)
332
+ except Exception:
333
  continue
334
+ elif is_inside(box1, box3):
335
+ box_added = True
336
+ break
337
  if not box_added:
338
+ filtered_boxes.append({
339
+ 'type': 'icon',
340
+ 'bbox': box1_elem['bbox'],
341
+ 'interactivity': True,
342
+ 'content': None
343
+ })
344
  else:
345
  filtered_boxes.append(box1)
346
+ return filtered_boxes # Optionally, you could return torch.tensor(filtered_boxes) if needed
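# Example usage (sketch): one detected icon box and one OCR box, both in normalized xyxy.
example_icon_boxes = [
    {'type': 'icon', 'bbox': [0.10, 0.10, 0.30, 0.20], 'interactivity': True, 'content': None},
]
example_ocr_boxes = [
    {'type': 'text', 'bbox': [0.12, 0.12, 0.28, 0.18], 'interactivity': False, 'content': 'Submit'},
]
# The OCR box lies inside the icon box, so its text is attached to the icon's bbox
# (an icon entry with content=None is kept alongside it).
merged = remove_overlap_new(example_icon_boxes, iou_threshold=0.7, ocr_bbox=example_ocr_boxes)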
347
 
348
 
349
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
350
+ """
351
+ Loads an image and applies transformations.
352
+
353
+ Returns:
354
+ image: Original image as a NumPy array.
355
+ image_transformed: Transformed tensor.
356
+ """
357
+ transform = T.Compose([
358
+ T.RandomResize([800], max_size=1333),
359
+ T.ToTensor(),
360
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
361
+ ])
362
  image_source = Image.open(image_path).convert("RGB")
363
  image = np.asarray(image_source)
364
  image_transformed, _ = transform(image_source, None)
365
  return image, image_transformed
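# Note (assumption): plain torchvision.transforms has no RandomResize and its Compose
# pipelines take a single image argument, so the transform above likely expects a
# detection-style transforms module (e.g. GroundingDINO's datasets.transforms) bound to T.
# A rough standard-torchvision sketch of the same preprocessing:
def load_image_torchvision(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    tv_transform = T.Compose([
        T.Resize(800, max_size=1333),  # shorter side to 800 px, longer side capped at 1333 px
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_source = Image.open(image_path).convert("RGB")
    return np.asarray(image_source), tv_transform(image_source)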
366
 
367
 
368
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str],
369
+ text_scale: float, text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
370
+ """
371
+ Annotates an image with bounding boxes and labels.
372
+
373
+ Args:
374
+ image_source: Source image as a NumPy array.
375
+ boxes: Bounding boxes in cxcywh format (normalized).
376
+ logits: Confidence scores for each bounding box.
377
+ phrases: List of labels.
378
+ text_scale, text_padding, text_thickness, thickness: Annotation parameters.
379
+
380
  Returns:
381
+ Annotated image and a dictionary of label coordinates.
382
  """
383
  h, w, _ = image_source.shape
384
  boxes = boxes * torch.Tensor([w, h, w, h])
 
388
 
389
  labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
390
 
391
+ # Import the custom box annotator from your project structure.
392
+ from util.box_annotator import BoxAnnotator
393
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,
394
+ text_thickness=text_thickness, thickness=thickness)
395
  annotated_frame = image_source.copy()
396
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))
397
 
398
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
399
  return annotated_frame, label_coordinates
400
 
401
 
402
  def predict(model, image, caption, box_threshold, text_threshold):
 
403
  """
404
+ Uses a Hugging Face model to perform grounded object detection.
405
+
406
+ Args:
407
+ model: Dictionary with 'model' and 'processor'.
408
+ image: Input PIL image.
409
+ caption: Caption text.
410
+ box_threshold: Confidence threshold for boxes.
411
+ text_threshold: Threshold for text detection.
412
+
413
+ Returns:
414
+ boxes, logits, phrases from the detection.
415
+ """
416
+ model_obj, processor = model['model'], model['processor']
417
+ device = model_obj.device
418
 
419
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
420
  with torch.no_grad():
421
+ outputs = model_obj(**inputs)
422
 
423
  results = processor.post_process_grounded_object_detection(
424
  outputs,
425
  inputs.input_ids,
426
+ box_threshold=box_threshold,
427
+ text_threshold=text_threshold,
428
  target_sizes=[image.size[::-1]]
429
  )[0]
430
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
 
432
 
433
 
434
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
435
+ """
436
+ Uses a YOLO model for object detection.
437
+
438
+ Args:
439
+ model: YOLO model instance.
440
+ image_path: Path to the image.
441
+ box_threshold: Confidence threshold.
442
+ imgsz: Image size for scaling (if scale_img is True).
443
+ scale_img: Boolean flag to scale the image.
444
+ iou_threshold: IoU threshold for non-max suppression.
445
+
446
+ Returns:
447
+ Bounding boxes, confidence scores, and placeholder phrases.
448
+ """
449
  kwargs = {
450
+ 'conf': box_threshold, # Confidence threshold
451
+ 'iou': iou_threshold, # IoU threshold
452
  'verbose': False
453
  }
 
454
  if scale_img:
455
  kwargs['imgsz'] = imgsz
456
+
457
  results = model.predict(image_path, **kwargs)
458
  boxes = results[0].boxes.xyxy
459
  conf = results[0].boxes.conf
460
  return boxes, conf, [str(i) for i in range(len(boxes))]
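# Example usage (sketch; path and thresholds are placeholders):
yolo_model = get_yolo_model("weights/icon_detect/best.pt")
example_boxes, example_conf, example_phrases = predict_yolo(
    yolo_model, "screenshot.png", box_threshold=0.05, imgsz=640, scale_img=False)
# example_boxes are absolute xyxy pixel coordinates; example_conf holds per-box confidences.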
461
 
462
 
463
+ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None,
464
+ text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None,
465
+ ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None, scale_img=False,
466
+ imgsz=None, batch_size=None):
467
+ """
468
+ Processes an image to generate semantic (SOM) labels.
469
+
470
+ Args:
471
+ img_path: Path to the image.
472
+ model: YOLO model for detection.
473
+ BOX_TRESHOLD: Confidence threshold for box prediction.
474
+ output_coord_in_ratio: If True, output coordinates in ratio.
475
+ ocr_bbox: OCR bounding boxes.
476
+ text_scale, text_padding: Parameters for drawing annotations.
477
+ draw_bbox_config: Custom configuration for bounding box drawing.
478
+ caption_model_processor: Dictionary with caption model and processor.
479
+ ocr_text: List of OCR-detected texts.
480
+ use_local_semantics: Whether to use local semantic processing.
481
+ iou_threshold: IoU threshold for filtering overlaps.
482
+ prompt: Optional caption prompt.
483
+ scale_img: Whether to scale the image.
484
+ imgsz: Image size for YOLO.
485
+ batch_size: Batch size for captioning.
486
+
487
+ Returns:
488
+ Encoded annotated image, label coordinates, and filtered boxes.
489
  """
490
  image_source = Image.open(img_path).convert("RGB")
491
  w, h = image_source.size
492
  if not imgsz:
493
  imgsz = (h, w)
494
+ # Run YOLO detection
495
+ xyxy, logits, phrases = predict_yolo(
496
+ model=model, image_path=img_path, box_threshold=BOX_TRESHOLD,
497
+ imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1
498
+ )
499
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
500
+ image_source_np = np.asarray(image_source)
501
  phrases = [str(i) for i in range(len(phrases))]
502
 
503
+ # Process OCR bounding boxes (if any)
 
504
  if ocr_bbox:
505
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
506
+ ocr_bbox = ocr_bbox.tolist()
507
  else:
508
  print('no ocr bbox!!!')
509
  ocr_bbox = None
 
 
 
510
 
511
+ ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt}
512
+ for box, txt in zip(ocr_bbox, ocr_text)]
513
+ xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None}
514
+ for box in xyxy.tolist()]
515
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
516
 
517
+ # Sort filtered boxes so that boxes with 'content' == None are at the end
518
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
 
519
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
520
+ filtered_boxes_tensor = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
521
 
522
+ # Generate parsed icon semantics if required
 
523
  if use_local_semantics:
524
  caption_model = caption_model_processor['model']
525
+ if 'phi3_v' in caption_model.config.model_type:
526
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes_tensor, ocr_bbox, image_source_np, caption_model_processor)
527
  else:
528
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes_tensor, starting_idx, image_source_np, caption_model_processor, prompt=prompt, batch_size=batch_size)
529
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
530
  icon_start = len(ocr_text)
531
  parsed_content_icon_ls = []
532
+ # Fill boxes with no OCR content with parsed icon content
533
+ for box in filtered_boxes_elem:
534
+ if box['content'] is None and parsed_content_icon:
535
  box['content'] = parsed_content_icon.pop(0)
536
  for i, txt in enumerate(parsed_content_icon):
537
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
 
540
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
541
  parsed_content_merged = ocr_text
542
 
543
+ filtered_boxes_cxcywh = box_convert(boxes=filtered_boxes_tensor, in_fmt="xyxy", out_fmt="cxcywh")
544
+ phrases = [i for i in range(len(filtered_boxes_cxcywh))]
 
545
 
546
+ # Annotate image with bounding boxes and labels
547
  if draw_bbox_config:
548
+ annotated_frame, label_coordinates = annotate(
549
+ image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases, **draw_bbox_config
550
+ )
551
  else:
552
+ annotated_frame, label_coordinates = annotate(
553
+ image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases,
554
+ text_scale=text_scale, text_padding=text_padding
555
+ )
556
 
557
  pil_img = Image.fromarray(annotated_frame)
558
  buffered = io.BytesIO()
559
  pil_img.save(buffered, format="PNG")
560
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
561
+
562
  if output_coord_in_ratio:
563
+ label_coordinates = {k: [v[0] / w, v[1] / h, v[2] / w, v[3] / h] for k, v in label_coordinates.items()}
 
564
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
565
 
566
  return encoded_image, label_coordinates, filtered_boxes_elem
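# Example (sketch): a small hypothetical helper showing how the outputs of
# get_som_labeled_img can be consumed; it is not defined anywhere else in this module.
def save_som_output(encoded_image: str, label_coordinates: dict, out_path: str = "annotated_screenshot.png"):
    # Decode the base64 PNG returned by get_som_labeled_img and print the per-label coordinates.
    Image.open(io.BytesIO(base64.b64decode(encoded_image))).save(out_path)
    for label, box in label_coordinates.items():
        # box is a 4-element coordinate, rescaled to ratios when output_coord_in_ratio=True
        print(label, box)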
567
 
568
 
569
  def get_xywh(input):
570
+ """
571
+ Converts a bounding box from a list of two points into (x, y, width, height).
572
+ """
573
+ x, y = input[0][0], input[0][1]
574
+ w = input[2][0] - input[0][0]
575
+ h = input[2][1] - input[0][1]
576
+ return int(x), int(y), int(w), int(h)
577
+
578
 
579
  def get_xyxy(input):
580
+ """
581
+ Converts a bounding box from a list of two points into (x, y, x2, y2).
582
+ """
583
+ x, y = input[0][0], input[0][1]
584
+ x2, y2 = input[2][0], input[2][1]
585
+ return int(x), int(y), int(x2), int(y2)
586
+
587
 
588
  def get_xywh_yolo(input):
589
+ """
590
+ Converts a YOLO-style bounding box (x1, y1, x2, y2) into (x, y, width, height).
591
+ """
592
+ x, y = input[0], input[1]
593
+ w = input[2] - input[0]
594
+ h = input[3] - input[1]
595
+ return int(x), int(y), int(w), int(h)
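# Quick sanity sketch for the three converters above (toy values):
_example_quad = [[10, 20], [110, 20], [110, 60], [10, 60]]  # EasyOCR/PaddleOCR corner format
assert get_xywh(_example_quad) == (10, 20, 100, 40)
assert get_xyxy(_example_quad) == (10, 20, 110, 60)
assert get_xywh_yolo([10, 20, 110, 60]) == (10, 20, 100, 40)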
596
 
597
 
598
+ def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
599
+ """
600
+ Runs OCR on the given image using PaddleOCR or EasyOCR and optionally displays annotated results.
601
+
602
+ Returns:
603
+ A tuple containing:
604
+ - A tuple (text, bounding boxes)
605
+ - The goal_filtering parameter (unchanged)
606
+ """
607
  if use_paddleocr:
608
+ text_threshold = 0.5 if easyocr_args is None else easyocr_args.get('text_threshold', 0.5)
 
 
 
609
  result = paddle_ocr.ocr(image_path, cls=False)[0]
610
  conf = [item[1] for item in result]
611
  coord = [item[0] for item in result if item[1][1] > text_threshold]
 
614
  if easyocr_args is None:
615
  easyocr_args = {}
616
  result = reader.readtext(image_path, **easyocr_args)
 
617
  coord = [item[0] for item in result]
618
  text = [item[1] for item in result]
619
+
620
  if display_img:
621
  opencv_img = cv2.imread(image_path)
622
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
623
  bb = []
624
  for item in coord:
625
  x, y, a, b = get_xywh(item)
 
626
  bb.append((x, y, a, b))
627
+ cv2.rectangle(opencv_img, (x, y), (x + a, y + b), (0, 255, 0), 2)
 
 
628
  plt.imshow(opencv_img)
629
  else:
630
  if output_bb_format == 'xywh':
631
  bb = [get_xywh(item) for item in coord]
632
  elif output_bb_format == 'xyxy':
633
  bb = [get_xyxy(item) for item in coord]
 
634
  return (text, bb), goal_filtering
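# End-to-end example (sketch): OCR first, then SOM labeling. The weight paths, thresholds and
# image path below are placeholders / assumptions, not values prescribed by this module.
if __name__ == "__main__":
    example_image_path = "screenshot.png"
    som_model = get_yolo_model("weights/icon_detect/best.pt")
    caption_model_processor = get_caption_model_processor(
        "florence2", model_name_or_path="microsoft/Florence-2-base")

    (ocr_text, ocr_bbox), _ = check_ocr_box(
        example_image_path, display_img=False, output_bb_format='xyxy',
        easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=False)

    encoded_image, label_coordinates, parsed_content = get_som_labeled_img(
        example_image_path, model=som_model, BOX_TRESHOLD=0.05, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, caption_model_processor=caption_model_processor,
        ocr_text=ocr_text, use_local_semantics=True, iou_threshold=0.7)
    print(parsed_content[:5])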