|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Processor class for KOSMOS-2.""" |
|
|
|
import copy |
|
import math |
|
import re |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
|
|
from transformers.image_processing_utils import BatchFeature |
|
from transformers.image_utils import ImageInput, is_batched |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy |
|
from transformers.utils import TensorType, is_tf_available, is_torch_available |
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
if is_tf_available(): |
|
import tensorflow as tf |
|
|
|
|
|
BboxInput = Union[ |
|
List[Tuple[int, int]], |
|
List[Tuple[float, float, float, float]], |
|
List[List[Tuple[int, int]]], |
|
    List[List[Tuple[float, float, float, float]]],
|
] |
|
|
|
|
|
class Kosmos2Processor(ProcessorMixin): |
|
r""" |
|
    Constructs a KOSMOS-2 processor which wraps a CLIP image processor and a KOSMOS-2 tokenizer into a single
|
processor. |
|
|
|
[`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`Kosmos2TokenizerFast`]. See the |
|
docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`] for more information. |
|
|
|
Args: |
|
image_processor (`CLIPImageProcessor`): |
|
An instance of [`CLIPImageProcessor`]. The image processor is a required input. |
|
tokenizer (`Kosmos2TokenizerFast`): |
|
            An instance of [`Kosmos2TokenizerFast`]. The tokenizer is a required input.
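
    Example (an illustrative sketch; the checkpoint name and image file below are assumptions):

        >>> from PIL import Image
        >>> from transformers import AutoProcessor

        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> image = Image.open("snowman.png")
        >>> inputs = processor(images=image, text="<grounding> An image of", return_tensors="pt")
        >>> # `inputs` holds `input_ids`, `attention_mask`, `img_attn_mask` and `pixel_values`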
|
""" |
|
attributes = ["image_processor", "tokenizer"] |
|
image_processor_class = "CLIPImageProcessor" |
|
tokenizer_class = "AutoTokenizer" |
|
|
|
def __init__(self, image_processor, tokenizer): |
|
tokenizer.return_token_type_ids = False |
|
super().__init__(image_processor, tokenizer) |
|
self.current_processor = self.image_processor |
|
|
|
def __call__( |
|
self, |
|
images: ImageInput = None, |
|
text: Union[TextInput, List[TextInput]] = None, |
|
bboxes: BboxInput = None, |
|
num_image_tokens: Optional[int] = 64, |
|
first_image_token_id: Optional[int] = None, |
|
add_special_tokens: bool = True, |
|
padding: Union[bool, str, PaddingStrategy] = False, |
|
truncation: Union[bool, str, TruncationStrategy] = None, |
|
max_length: Optional[int] = None, |
|
stride: int = 0, |
|
pad_to_multiple_of: Optional[int] = None, |
|
return_attention_mask: Optional[bool] = None, |
|
return_overflowing_tokens: bool = False, |
|
return_special_tokens_mask: bool = False, |
|
return_offsets_mapping: bool = False, |
|
return_token_type_ids: bool = False, |
|
return_length: bool = False, |
|
verbose: bool = True, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
**kwargs, |
|
) -> BatchFeature: |
|
""" |
|
        This method uses [`CLIPImageProcessor.__call__`] to prepare image(s) for the model, and
        [`Kosmos2TokenizerFast.__call__`] to prepare text for the model.
|
|
|
Please refer to the docstring of the above two methods for more information. |
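
        A sketch of a grounded call (illustrative; the phrase and box values are made up):

            >>> inputs = processor(
            ...     images=image,
            ...     text="<grounding> An image of <phrase> a snowman </phrase>",
            ...     bboxes=[(0.25, 0.25, 0.75, 0.75)],
            ...     return_tensors="pt",
            ... )

        `bboxes` takes one entry per `<phrase> ... </phrase>` pair in `text`, each given either as a pair of patch
        indices or as 4 normalized `(x1, y1, x2, y2)` coordinates.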
|
""" |
|
if text is None: |
|
raise ValueError("You have to specify at least `text`.") |
|
|
|
text = self.preprocess_text(text, images, bboxes, num_image_tokens=num_image_tokens) |
|
|
|
encoding = BatchFeature() |
|
|
|
text_encoding = self.tokenizer( |
|
text=text, |
|
add_special_tokens=add_special_tokens, |
|
padding=padding, |
|
truncation=truncation, |
|
max_length=max_length, |
|
stride=stride, |
|
pad_to_multiple_of=pad_to_multiple_of, |
|
return_attention_mask=return_attention_mask, |
|
return_overflowing_tokens=return_overflowing_tokens, |
|
return_special_tokens_mask=return_special_tokens_mask, |
|
return_offsets_mapping=return_offsets_mapping, |
|
return_token_type_ids=return_token_type_ids, |
|
return_length=return_length, |
|
verbose=verbose, |
|
return_tensors=return_tensors, |
|
**kwargs, |
|
) |
|
encoding.update(text_encoding) |
|
|
|
if images is not None: |
|
image_encoding = self.image_processor(images, return_tensors=return_tensors) |
|
encoding.update(image_encoding) |
|
|
|
|
|
            # By default, reserve the ids right after `unk_token_id` for the image tokens.
            if first_image_token_id is None:
                first_image_token_id = self.tokenizer.unk_token_id + 1
|
|
|
|
|
            # Whether a `<s>` (bos) token is prepended, which shifts the image token positions by one.
            with_bos = add_special_tokens
|
|
|
|
|
|
|
            # The image placeholder tokens start right after the `<image>` begin tag (and the optional `<s>`).
            start_index = int(with_bos) + 1
|
|
|
            if return_tensors:
                # Work on a numpy array first; convert to the requested tensor type at the end.
                input_ids = np.array(encoding["input_ids"])
                # Overwrite the placeholder `<image>` token ids with the reserved image token ids.
                input_ids[:, start_index : (start_index + num_image_tokens)] = np.arange(
                    first_image_token_id, first_image_token_id + num_image_tokens
                )
|
|
|
batch_size, seq_len = input_ids.shape[:2] |
|
                img_attn_mask = []
                # `<s>` (bos) token, if added
                if with_bos:
                    img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # `<image>` begin tag
                img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # the image placeholder tokens
                img_attn_mask.append(np.ones(shape=(batch_size, num_image_tokens), dtype=np.int64))
                # `</image>` end tag
                img_attn_mask.append(np.zeros(shape=(batch_size, 1), dtype=np.int64))
                # the remaining text tokens
                seq_len -= int(with_bos) + 1 + num_image_tokens + 1
                img_attn_mask.append(np.zeros(shape=(batch_size, seq_len), dtype=np.int64))
|
|
|
|
|
img_attn_mask = np.concatenate(img_attn_mask, axis=1) |
|
|
|
|
|
if return_tensors == "pt": |
|
input_ids = torch.from_numpy(input_ids) |
|
img_attn_mask = torch.from_numpy(img_attn_mask) |
|
elif return_tensors == "tf": |
|
input_ids = tf.convert_to_tensor(input_ids) |
|
img_attn_mask = tf.convert_to_tensor(img_attn_mask) |
|
|
|
encoding["input_ids"] = input_ids |
|
encoding["img_attn_mask"] = img_attn_mask |
|
|
|
            else:
                # No tensor type requested: work with plain Python lists.
                # The reserved ids that will replace the placeholder `<image>` tokens.
                image_token_ids = list(range(first_image_token_id, first_image_token_id + num_image_tokens))
                # mask covering the `<image>` begin tag, the image placeholder tokens and the `</image>` end tag
                base_img_attn_mask = [0] + [1] * num_image_tokens + [0]
|
|
|
|
|
input_ids = [] |
|
img_attn_mask = [] |
|
all_input_ids = encoding["input_ids"] |
|
|
|
if isinstance(text, str): |
|
all_input_ids = [all_input_ids] |
|
                for text_ids in all_input_ids:
                    # Replace the placeholder `<image>` token ids with the reserved image token ids.
                    text_ids = text_ids[:start_index] + image_token_ids + text_ids[start_index + num_image_tokens :]
                    input_ids.append(text_ids)

                    mask = copy.copy(base_img_attn_mask)
                    if with_bos:
                        # account for the `<s>` (bos) token
                        mask = [0] + mask
                    # pad with zeros to cover the remaining text tokens
                    mask += [0] * (len(text_ids) - len(mask))
                    img_attn_mask.append(mask)
|
|
|
|
|
if isinstance(text, str): |
|
input_ids = input_ids[0] |
|
img_attn_mask = img_attn_mask[0] |
|
|
|
encoding["input_ids"] = input_ids |
|
encoding["img_attn_mask"] = img_attn_mask |
|
|
|
return encoding |
|
|
|
def preprocess_text( |
|
self, |
|
texts: Union[TextInput, List[TextInput]], |
|
images: ImageInput = None, |
|
bboxes: BboxInput = None, |
|
num_image_tokens: Optional[int] = 64, |
|
) -> Union[str, List[str]]: |
|
"""Add image and bounding box information to `texts` as image and patch index tokens. |
|
|
|
Args: |
|
texts (`Union[TextInput, List[TextInput]]`): The texts to be processed. |
|
            images (`ImageInput`, *optional*): The images associated with `texts`.
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*): The bounding boxes associated with `texts`.
            num_image_tokens (`int`, *optional*, defaults to 64): The number of image tokens (used as latent queries). This should correspond to the `latent_query_num` attribute in `Kosmos2Config`.
|
|
|
Returns: |
|
`Union[TextInput, List[TextInput]]`: The processed texts with image and patch index tokens. |
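
        Example (illustrative; the bbox values are made up and the exact token spacing in the output may differ):

            >>> processor.preprocess_text(
            ...     "<grounding> An image of <phrase> a snowman </phrase>",
            ...     images=image,
            ...     bboxes=[(0.25, 0.25, 0.75, 0.75)],
            ... )

        The result starts with an `<image>` begin tag, 64 `<image>` placeholder tokens and an `</image>` end tag,
        and the phrase is followed by `<object><patch_index_0264><patch_index_0759></object>` (the patch indices
        for the box above on a 32x32 grid, i.e. assuming 1024 patch index tokens in the tokenizer).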
|
""" |
|
|
|
img_tokens = ["<image>"] * num_image_tokens |
|
img_info = " ".join(["<image>"] + img_tokens + ["</image>"]) |
|
|
|
def check_bboxes_for_single_text(bboxes): |
|
""" |
|
Check `bboxes` for a single text example. It could be |
|
                - `None`: no bounding box associated with a text.
                - A list with each element being the bounding boxes associated with one `<phrase> ... </phrase>` pair
                  found in a text. This could be:
                      - `None`: no bounding box associated with a `<phrase> ... </phrase>` pair.
                      - A tuple of 2 integers: A single bounding box specified by patch indices.
                      - A tuple of 4 floating point numbers: A single bounding box specified by (normalized) coordinates.
|
- A list containing the above 2 tuple types: Multiple bounding boxes for a |
|
`<phrase> ... </phrase>` pair. |
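
            For example, for a text containing three `<phrase> ... </phrase>` pairs, a valid value is
            `[None, (1, 2), [(0.1, 0.2, 0.3, 0.4), (5, 6)]]` (the numbers here are purely illustrative).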
|
""" |
|
if bboxes is None: |
|
return |
|
elif not isinstance(bboxes, list): |
|
raise ValueError("`bboxes` (for a single text example) should be `None` or a list.") |
|
|
|
|
|
for bbox in bboxes: |
|
if bbox is None: |
|
continue |
|
elif not isinstance(bbox, list): |
|
bbox = [bbox] |
|
for elt in bbox: |
|
if not isinstance(elt, tuple) or not ( |
|
(len(elt) == 2 and all(isinstance(x, int) for x in elt)) |
|
or (len(elt) == 4 and all(isinstance(x, float) for x in elt)) |
|
): |
|
raise ValueError( |
|
"Each element in `bboxes` (for a single text example) should be `None`, a tuple containing " |
|
"2 integers or 4 float point numbers, or a list containing such tuples. Also " |
|
"make sure the arguments `texts` and `bboxes` passed to `preprocess_text` are both in " |
|
"batches or both for a single example." |
|
) |
|
|
|
def preprocess_single(text, image, bboxes): |
|
if image is not None: |
|
|
|
text = f"{img_info} {text}" |
|
|
|
|
|
text = self._insert_patch_index_tokens(text, bboxes) |
|
text = self._add_remove_spaces_around_tag_tokens(text) |
|
|
|
return text |
|
|
|
|
|
batched = True |
|
if isinstance(texts, str): |
|
batched = False |
|
texts = [texts] |
|
|
|
if images is None: |
|
images = [None] * len(texts) |
|
elif not is_batched(images): |
|
images = [images] |
|
if len(texts) != len(images): |
|
raise ValueError( |
|
f"The number of examples in `texts` and `images` should be the same. Got {len(texts)} v.s. {len(images)} instead." |
|
) |
|
|
|
if not batched: |
|
check_bboxes_for_single_text(bboxes) |
|
bboxes = [bboxes] |
|
elif bboxes is not None: |
|
if not isinstance(bboxes, list): |
|
raise ValueError("`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.") |
|
for x in bboxes: |
|
check_bboxes_for_single_text(x) |
|
else: |
|
bboxes = [None] * len(texts) |
|
|
|
if len(bboxes) != len(texts): |
|
raise ValueError( |
|
f"The number of examples in `texts` and `bboxes` should be the same. Got {len(texts)} v.s. {len(bboxes)} instead." |
|
) |
|
|
|
result = [preprocess_single(text, image, bbox) for text, image, bbox in zip(texts, images, bboxes)] |
|
|
|
if not batched: |
|
result = result[0] |
|
|
|
return result |
|
|
|
|
|
def batch_decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please |
|
refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
|
|
|
def decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer |
|
to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.decode(*args, **kwargs) |
|
|
|
    def post_process_generation(self, text, cleanup_and_extract=True):
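        """
        Post-process generated text: keep only the part after the `</image>` end tag and, if
        `cleanup_and_extract=True`, remove the special tag tokens and extract `(entity_name, (start, end), bboxes)`
        triples via `clean_text_and_extract_entities_with_bboxes`.
        """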
|
|
|
caption = text.split("</image>")[-1] |
|
if cleanup_and_extract: |
|
return clean_text_and_extract_entities_with_bboxes(caption) |
|
return caption |
|
|
|
@property |
|
|
|
def model_input_names(self): |
|
tokenizer_input_names = self.tokenizer.model_input_names |
|
image_processor_input_names = self.image_processor.model_input_names |
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) |
|
|
|
def _insert_patch_index_tokens(self, text: str, bboxes: Union[List[Tuple[int]], List[Tuple[float]]]) -> str: |
|
if bboxes is None or len(bboxes) == 0: |
|
return text |
|
|
|
matched_phrases = list(re.finditer(r"<phrase>.+?</phrase>", string=text)) |
|
if len(matched_phrases) != len(bboxes): |
|
raise ValueError( |
|
f"The number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got {len(matched_phrases)} v.s. {len(bboxes)} instead." |
|
) |
|
|
|
|
|
|
|
curr_pos = 0 |
|
buffer = [] |
|
for matched, bbox in zip(matched_phrases, bboxes): |
|
_, end = matched.span() |
|
buffer.append(text[curr_pos:end]) |
|
curr_pos = end |
|
|
|
if bbox is None: |
|
continue |
|
|
|
if isinstance(bbox, tuple): |
|
bbox = [bbox] |
|
patch_index_strings = [] |
|
|
|
for box in bbox: |
|
patch_index_1, patch_index_2 = self._convert_bbox_to_patch_index_tokens(box) |
|
patch_index_strings.append(f"{patch_index_1} {patch_index_2}") |
|
position_str = " </delimiter_of_multi_objects/> ".join(patch_index_strings) |
|
buffer.append(f"<object> {position_str} </object>") |
|
|
|
if curr_pos < len(text): |
|
buffer.append(text[curr_pos:]) |
|
|
|
text = "".join(buffer) |
|
return text |
|
|
|
def _convert_bbox_to_patch_index_tokens( |
|
self, bbox: Union[Tuple[int, int], Tuple[float, float, float, float]] |
|
) -> Tuple[str, str]: |
|
|
|
if len(bbox) == 2: |
|
idx_1, idx_2 = bbox |
|
|
|
else: |
|
|
|
num_patches_per_side = int(math.sqrt(self.tokenizer.num_patch_index_tokens)) |
|
idx_1, idx_2 = coordinate_to_patch_index(bbox, num_patches_per_side) |
|
|
|
token_1 = f"<patch_index_{str(idx_1).zfill(4)}>" |
|
token_2 = f"<patch_index_{str(idx_2).zfill(4)}>" |
|
|
|
return token_1, token_2 |
|
|
|
def _add_remove_spaces_around_tag_tokens(self, text): |
|
""" |
|
Remove spaces before tag tokens (e.g. `<x>`). Also ensure a space after a tag token, if it is not followed by |
|
another tag token (this is not technically necessary, but good for a standard/consistent format). This avoids |
|
the inconsistency of tokenization results between kosmos-2 slow and fast tokenizers. |
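
        For example (illustrative), `"a <phrase> snowman </phrase>"` becomes roughly `"a<phrase> snowman</phrase>"`:
        the spaces before the tag tokens are dropped while the space after `<phrase>` is kept.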
|
""" |
|
|
|
tag_tokens = set( |
|
self.tokenizer.tag_tokens |
|
+ [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.tokenizer.num_patch_index_tokens)] |
|
) |
|
pattern = "|".join(tag_tokens) |
|
splits = re.split(rf"({pattern})", text) |
|
|
|
output = "" |
|
prev_str_in_targets = False |
|
for split in splits: |
|
if split in tag_tokens: |
|
prev_str_in_targets = True |
|
output = output.rstrip() + split |
|
else: |
|
|
|
|
|
if prev_str_in_targets and not split.startswith(" "): |
|
output += " " + split |
|
else: |
|
output += split |
|
prev_str_in_targets = False |
|
|
|
return output |
|
|
|
|
|
def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patches_per_side: int) -> Tuple[int, int]: |
|
"""Convert a bounding box to a pair of patch indices. |
|
|
|
Args: |
|
bbox (`Tuple[float, float, float, float]`): |
|
The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left |
|
            and lower-right corners of the box. It should satisfy x2 > x1 and y2 > y1.
|
num_patches_per_side (`int`): the number of patches along each side. |
|
|
|
Returns: |
|
`Tuple[int, int]`: A pair of patch indices. |
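
    Example: with `num_patches_per_side=32` and `bbox=(0.25, 0.25, 0.75, 0.75)`, the upper-left patch is
    `(ul_x, ul_y) = (8, 8)` and the lower-right patch is `(lr_x, lr_y) = (23, 23)`, giving the indices
    `(8 * 32 + 8, 23 * 32 + 23) = (264, 759)`.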
|
""" |
|
(x1, y1, x2, y2) = bbox |
|
|
|
ul_x = math.floor(x1 * num_patches_per_side) |
|
ul_y = math.floor(y1 * num_patches_per_side) |
|
|
|
lr_x = math.ceil(x2 * num_patches_per_side - 1) |
|
lr_y = math.ceil(y2 * num_patches_per_side - 1) |
|
|
|
ul_idx = ul_y * num_patches_per_side + ul_x |
|
lr_idx = lr_y * num_patches_per_side + lr_x |
|
|
|
return ul_idx, lr_idx |
|
|
|
|
|
|
|
|
|
def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int): |
|
""" |
|
Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a |
|
bounding box, returns the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2). |
|
|
|
Args: |
|
ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box. |
|
lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box. |
|
num_patches_per_side (`int`): the number of patches along each side. |
|
|
|
Returns: |
|
`Tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2). |
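
    Example: with `num_patches_per_side=32`, `ul_idx=264` and `lr_idx=759` correspond to the patches `(8, 8)` and
    `(23, 23)`; since the box spans several rows and columns, the patch centers are used, giving
    `(0.265625, 0.265625, 0.734375, 0.734375)`.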
|
""" |
|
|
|
cell_size = 1.0 / num_patches_per_side |
|
|
|
|
|
ul_x = ul_idx % num_patches_per_side |
|
ul_y = ul_idx // num_patches_per_side |
|
|
|
lr_x = lr_idx % num_patches_per_side |
|
lr_y = lr_idx // num_patches_per_side |
|
|
|
|
|
    if ul_idx == lr_idx:  # the box covers a single patch: use the patch edges
|
x1 = ul_x * cell_size |
|
y1 = ul_y * cell_size |
|
x2 = lr_x * cell_size + cell_size |
|
y2 = lr_y * cell_size + cell_size |
|
    elif ul_x == lr_x or ul_y == lr_y:  # the box covers a single row or column of patches: use the patch edges
|
x1 = ul_x * cell_size |
|
y1 = ul_y * cell_size |
|
x2 = lr_x * cell_size + cell_size |
|
y2 = lr_y * cell_size + cell_size |
|
    else:  # the box spans multiple rows and columns: use the patch centers
|
x1 = ul_x * cell_size + cell_size / 2 |
|
y1 = ul_y * cell_size + cell_size / 2 |
|
x2 = lr_x * cell_size + cell_size / 2 |
|
y2 = lr_y * cell_size + cell_size / 2 |
|
|
|
return x1, y1, x2, y2 |
|
|
|
|
|
|
|
|
|
def extract_entities_with_patch_indices(text): |
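    """Extract `(entity, span, patch index pairs)` triples from `text`.

    Each returned element contains the phrase found between `<phrase>` and `</phrase>` (or the patch index tokens
    themselves when no phrase is present), the character span of that phrase in `text` (`(None, None)` if there is
    no phrase), and the list of `(upper-left, lower-right)` patch index pairs from the following
    `<object> ... </object>` block.
    """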
|
|
|
pattern = r'(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>' |
|
|
|
|
|
matches = re.finditer(pattern, text) |
|
|
|
|
|
entities_with_patch_indices = [] |
|
|
|
for match in matches: |
|
|
|
        # span of the phrase text (capture group 2), if a `<phrase> ... </phrase>` pair is present
        span = match.span(2)
        phrase_tag, phrase, match_content = match.groups()
        if not phrase_tag:
            phrase = None
            span = (None, None)
|
|
|
|
|
patch_index_pairs = match_content.split('</delimiter_of_multi_objects/>') |
|
|
|
entity_bboxes = [] |
|
for pair in patch_index_pairs: |
|
|
|
            # Searching in `pair[1:]` skips the first `<patch_index_xxxx>` token, so `y` matches the second one.
            x = re.search(r'<patch_index_(\d+)>', pair)
            y = re.search(r'<patch_index_(\d+)>', pair[1:])

            if x and y:
                entity_bboxes.append((int(x.group(1)), int(y.group(1))))
|
|
|
if phrase: |
|
entities_with_patch_indices.append((phrase, span, entity_bboxes)) |
|
else: |
|
for bbox in entity_bboxes: |
|
|
|
entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>" |
|
entities_with_patch_indices.append((entity, span, [bbox])) |
|
|
|
return entities_with_patch_indices |
|
|
|
|
|
|
|
def remove_special_fields(text): |
|
return re.sub('<.*?>', '', text) |
|
|
|
|
|
def adjust_entity_positions(entity, text): |
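    """Shift the character positions of `entity` (computed on the raw `text`) so that they become valid offsets
    into the text with all special fields (`<phrase>`, `<object>`, patch index tokens, ...) removed."""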
|
|
|
entity_name, (start, end) = entity |
|
adjusted_start = len(remove_special_fields(text[:start])) |
|
adjusted_end = len(remove_special_fields(text[:end])) |
|
adjusted_entity = (entity_name, (adjusted_start, adjusted_end)) |
|
return adjusted_entity |
|
|
|
|
|
|
|
|
|
def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32): |
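    """Remove the special tag tokens from `text` and extract the grounded entities.

    Returns a `(cleaned_text, entities)` pair, where each entity is an `(entity_name, (start, end), bboxes)` triple
    with character offsets into the cleaned text and bounding boxes given as normalized `(x1, y1, x2, y2)`
    coordinates (converted from patch indices with `num_patches_per_side`, 32 by default).
    """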
|
|
|
processed_text = remove_special_fields(text) |
|
|
|
entities_with_patch_indices = extract_entities_with_patch_indices(text) |
|
entities = [] |
|
for item in entities_with_patch_indices: |
|
entity, bboxes = item[0:2], item[2] |
|
adjusted_entity = adjust_entity_positions(entity, text) |
|
        bboxes_in_coords = [
            patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side) for bbox in bboxes
        ]
|
|
|
entities.append(adjusted_entity + (bboxes_in_coords,)) |
|
|
|
def cleanup_spaces(text, entities): |
|
new_text = text.strip() |
|
leading_spaces = len(text) - len(text.lstrip()) |
|
|
|
new_entities = [] |
|
for entity_name, (start, end), bboxes in entities: |
|
|
|
entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip()) |
|
entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip()) |
|
|
|
start = start - leading_spaces + entity_name_leading_spaces |
|
end = end - leading_spaces - entity_name_trailing_spaces |
|
entity_name = entity_name.strip() |
|
|
|
new_entities.append((entity_name, (start, end), bboxes)) |
|
|
|
return new_text, new_entities |
|
|
|
return cleanup_spaces(processed_text, entities) |
|
|