ydshieh HF staff commited on
Commit
8fdd748
1 Parent(s): facea15

Upload processing_kosmos2.py

Browse files
Files changed (1) hide show
  1. processing_kosmos2.py +117 -2
processing_kosmos2.py CHANGED
@@ -59,7 +59,7 @@ class Kosmos2Processor(ProcessorMixin):
59
  """
60
  attributes = ["image_processor", "tokenizer"]
61
  image_processor_class = "CLIPImageProcessor"
62
- tokenizer_class = "AutoTokenizer" # ("Kosmos2Tokenizer", "Kosmos2TokenizerFast")
63
 
64
  def __init__(self, image_processor, tokenizer):
65
  tokenizer.return_token_type_ids = False
@@ -332,7 +332,9 @@ class Kosmos2Processor(ProcessorMixin):
332
  return self.tokenizer.decode(*args, **kwargs)
333
 
334
  def post_processor_generation(self, text):
335
- return text.split("</image>")[-1]
 
 
336
 
337
  @property
338
  # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
@@ -455,6 +457,7 @@ def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patch
455
 
456
 
457
  # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
 
458
  def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
459
  """
460
  Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
@@ -496,3 +499,115 @@ def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: in
496
  y2 = lr_y * cell_size + cell_size / 2
497
 
498
  return x1, y1, x2, y2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  """
60
  attributes = ["image_processor", "tokenizer"]
61
  image_processor_class = "CLIPImageProcessor"
62
+ tokenizer_class = ("Kosmos2Tokenizer", "Kosmos2TokenizerFast")
63
 
64
  def __init__(self, image_processor, tokenizer):
65
  tokenizer.return_token_type_ids = False
 
332
  return self.tokenizer.decode(*args, **kwargs)
333
 
334
  def post_processor_generation(self, text):
335
+
336
+ caption = text.split("</image>")[-1]
337
+ return clean_text_and_extract_entities_with_bboxes(caption)
338
 
339
  @property
340
  # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
 
457
 
458
 
459
  # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
460
+ # (with format modifications)
461
  def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
462
  """
463
  Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
 
499
  y2 = lr_y * cell_size + cell_size / 2
500
 
501
  return x1, y1, x2, y2
502
+
503
+
504
+ # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33
505
+ # (with format modifications)
506
+ def extract_entities_with_patch_indices(text):
507
+ # The regular expression pattern for matching the required formats
508
+ pattern = r'(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>'
509
+
510
+ # Find all matches in the given string
511
+ matches = re.finditer(pattern, text)
512
+
513
+ # Initialize an empty list to store the valid patch_index combinations
514
+ entities_with_patch_indices = []
515
+
516
+ for match in matches:
517
+ # span of a `phrase` that is between <phrase> and </phrase>
518
+ span = match.span(2)
519
+ phrase_tag, phrase, match_content = match.groups()
520
+ if not phrase_tag:
521
+ phrase = None
522
+ span = (None, None)
523
+
524
+ # Split the match_content by the delimiter to get individual patch_index pairs
525
+ patch_index_pairs = match_content.split('</delimiter_of_multi_objects/>')
526
+
527
+ entity_bboxes = []
528
+ for pair in patch_index_pairs:
529
+ # Extract the xxxx and yyyy values from the patch_index pair
530
+ x = re.search(r'<patch_index_(\d+)>', pair)
531
+ y = re.search(r'<patch_index_(\d+)>', pair[1:])
532
+
533
+ if x and y:
534
+ if phrase:
535
+ entity_bboxes.append((int(x.group(1)), int(y.group(1))))
536
+ else:
537
+ entity_bboxes.append((int(x.group(1)), int(y.group(1))))
538
+
539
+ if phrase:
540
+ entities_with_patch_indices.append((phrase, span, entity_bboxes))
541
+ else:
542
+ for bbox in entity_bboxes:
543
+ # fake entity name
544
+ entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>"
545
+ entities_with_patch_indices.append((entity, span, [bbox]))
546
+
547
+ def cleanup_spaces(text, entities):
548
+ new_text = text.strip()
549
+
550
+ leading_spaces = text - text.lstrip(text)
551
+
552
+ new_entities = []
553
+ for entity_name, (start, end), bboxes in entities:
554
+
555
+ start = start - leading_spaces + (entity_name.lstrip(entity_name))
556
+ end = end - leading_spaces - (entity_name.rstrip(entity_name))
557
+ entity_name = entity_name.strip()
558
+
559
+ new_entities.append((entity_name, (start, end), bboxes))
560
+
561
+ return new_text, new_entities
562
+
563
+ return cleanup_spaces(entities_with_patch_indices)
564
+
565
+
566
+ # TODO: Be careful
567
+ def remove_special_fields(text):
568
+ return re.sub('<.*?>', '', text)
569
+
570
+
571
+ def adjust_entity_positions(entity, text):
572
+
573
+ entity_name, (start, end) = entity
574
+ adjusted_start = len(remove_special_fields(text[:start]))
575
+ adjusted_end = len(remove_special_fields(text[:end]))
576
+ adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
577
+ return adjusted_entity
578
+
579
+
580
+ # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
581
+ # (with format modifications)
582
+ def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
583
+
584
+ processed_text = remove_special_fields(text)
585
+
586
+ entities_with_patch_indices = extract_entities_with_patch_indices(text)
587
+ entities = []
588
+ for item in entities_with_patch_indices:
589
+ entity, bboxes = item[0:2], item[2]
590
+ adjusted_entity = adjust_entity_positions(entity, text)
591
+ bboxes_in_coords = list(map(lambda bbox: patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side), bboxes))
592
+
593
+ entities.append((adjusted_entity) + (bboxes_in_coords,))
594
+
595
+ def cleanup_spaces(text, entities):
596
+ new_text = text.strip()
597
+ leading_spaces = len(text) - len(text.lstrip())
598
+
599
+ new_entities = []
600
+ for entity_name, (start, end), bboxes in entities:
601
+
602
+ entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
603
+ entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())
604
+
605
+ start = start - leading_spaces + entity_name_leading_spaces
606
+ end = end - leading_spaces - entity_name_trailing_spaces
607
+ entity_name = entity_name.strip()
608
+
609
+ new_entities.append((entity_name, (start, end), bboxes))
610
+
611
+ return new_text, new_entities
612
+
613
+ return cleanup_spaces(processed_text, entities)