"""Processor class for ChatRex."""

import re
from typing import List, Optional, Union

try:
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
                                            TextKwargs)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

logger = logging.get_logger(__name__)


IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# "<i>" in DEFAULT_OBJECT_TOKEN is a placeholder that is replaced with the
# object index, producing "<obj0>", "<obj1>", ...
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
def xyxy_to_xywh(boxes):
    """
    Convert boxes from xyxy to xywh format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, w, h].
    """
    boxes = np.array(boxes)
    x_min, y_min, x_max, y_max = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )
    w = x_max - x_min
    h = y_max - y_min
    return np.stack([x_min, y_min, w, h], axis=1)


def xywh_to_xyxy(boxes):
    """
    Convert boxes from xywh to xyxy format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x, y, width, height].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, x_max, y_max].
    """
    boxes = np.array(boxes)
    x, y, width, height = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )
    x_max = x + width
    y_max = y + height
    return np.stack([x, y, x_max, y_max], axis=1)
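
# Quick sanity check for the two converters above (illustrative values only):
#
#   xyxy_to_xywh([[10, 20, 50, 80]])   # -> [[10, 20, 40, 60]]
#   xywh_to_xyxy([[10, 20, 40, 60]])   # -> [[10, 20, 50, 80]]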


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, keeping the original image centered."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
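
# Example: padding a landscape image to a square canvas. The grey fill value here
# is arbitrary; `process` below uses the image processor's `image_mean` instead.
#
#   img = Image.new("RGB", (640, 480))
#   expand2square(img, (127, 127, 127)).size   # -> (640, 640)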


def pad_boxes(gt_boxes, old_size):
    """Shift xywh boxes by the offsets introduced when the image is padded to a square."""
    old_w, old_h = old_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)

    if old_w > old_h:
        pad_top = (old_w - old_h) // 2
        pad_bottom = old_w - old_h - pad_top
        pad_left, pad_right = 0, 0
    else:
        pad_left = (old_h - old_w) // 2
        pad_right = old_h - old_w - pad_left
        pad_top, pad_bottom = 0, 0

    gt_boxes[:, 0] += pad_left
    gt_boxes[:, 1] += pad_top
    return gt_boxes


def resize_boxes(gt_boxes, old_size, new_size):
    """Rescale xywh boxes from the padded square image to the resized image.

    Both axes are scaled by the same factor relative to the square side
    max(old_w, old_h), matching the resize applied after `expand2square`.
    """
    old_w, old_h = old_size
    new_h, new_w = new_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)

    scale_x = new_w / max(old_w, old_h)
    scale_y = new_h / max(old_w, old_h)

    gt_boxes[:, 0] *= scale_x
    gt_boxes[:, 1] *= scale_y
    gt_boxes[:, 2] *= scale_x
    gt_boxes[:, 3] *= scale_y

    return gt_boxes
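
# Worked example of the box pipeline used in `process` (numbers are made up):
# a 640x480 image is padded to a 640x640 square, so boxes shift down by 80 px,
# then everything is scaled to the processed resolution (here 768x768).
#
#   boxes = xyxy_to_xywh([[100, 100, 200, 200]])        # -> [[100, 100, 100, 100]]
#   boxes = pad_boxes(boxes, (640, 480))                # -> [[100., 180., 100., 100.]]
#   boxes = resize_boxes(boxes, (640, 640), (768, 768))
#   xywh_to_xyxy(boxes)                                 # -> [[120., 216., 240., 336.]]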


def split_special_strings(input_string: str, special_strings: list[str] = None):
    """Split the input string into a list of strings, keeping the special strings.

    Args:
        input_string (str): The input string to split.
        special_strings (list[str]): The special substrings to split on; each match
            is kept as its own element of the output.

    Example:

        input_string = "<image>\n<obj0><objfeat><obj1><objfeat>\n I am happy today."
        special_strings = ["<image>", "<objfeat>"]
        output = ['<image>', '\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\n I am happy today.']

    Returns:
        list: A list of strings, with the special strings separated from the rest of the input string.
    """
    # Split on an alternation of the escaped special strings; the capturing group
    # keeps the delimiters, and empty chunks are dropped.
    pattern = "|".join(map(re.escape, special_strings))
    split_list = re.split(f"({pattern})", input_string)
    split_list = [s for s in split_list if s]
    return split_list


def tokenizer_image_object_token(prompt, tokenizer):
    """Tokenize a prompt, mapping <image> and <objfeat> tokens to sentinel ids."""
    bos_token_id = tokenizer.bos_token_id
    split_tokens = [DEFAULT_IMAGE_TOKEN, DEFAULT_OBJECT_FEATURE_TOKEN]
    chunks = split_special_strings(prompt, split_tokens)
    input_encode = [bos_token_id]
    for chunk in chunks:
        if chunk == DEFAULT_IMAGE_TOKEN:
            input_encode.append(IMAGE_TOKEN_INDEX)
        elif chunk == DEFAULT_OBJECT_FEATURE_TOKEN:
            input_encode.append(DEFAULT_OBJECT_INDEX)
        else:
            input_encode.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return input_encode
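
# Sketch of the id sequence this produces (placeholders, not real ids; the actual
# values depend on the tokenizer):
#
#   prompt = "<image>\n<obj0><objfeat>\nDescribe the object."
#   tokenizer_image_object_token(prompt, tokenizer)
#   # -> [bos_id, -200, *ids("\n<obj0>"), -300, *ids("\nDescribe the object.")]
#
# The -200 (IMAGE_TOKEN_INDEX) and -300 (DEFAULT_OBJECT_INDEX) entries are
# sentinel ids meant to be swapped for image and object features downstream.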


class ChatRexProcessor(ProcessorMixin):
    """Processor that prepares ChatRex inputs from an image, boxes, and a question."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer: AutoTokenizer = None, **kwargs):
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None
        self.template = dict(
            SYSTEM=('A chat between a curious user and an artificial '
                    'intelligence assistant. The assistant gives '
                    'helpful, detailed, and polite answers to the '
                    'user\'s questions. {system}\n '),
            INSTRUCTION=('USER: {input} ASSISTANT:'),
            SEP='\n')
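
    # For reference, the instruction template renders as, e.g.:
    #   self.template["INSTRUCTION"].format(input="<image>\nDescribe the image.", round=1)
    #   -> "USER: <image>\nDescribe the image. ASSISTANT:"
    # (the extra `round` keyword is ignored by str.format).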

    def process(
        self,
        image: Union[str, Image.Image],
        bbox: List[List[int]],
        question: str,
    ):
        """Prepare input data for inference.

        Args:
            image (Union[str, Image.Image]): The image to process, given as a file path or a PIL image.
            bbox (List[List[int]]): A list of bounding boxes for the image. Each bounding box should
                be in [x_min, y_min, x_max, y_max] order.
            question (str): The question to ask about the image.

        Returns:
            dict: A dict with "pixel_values", "pixel_values_aux", "gt_boxes" and "input_ids",
                each with a leading batch dimension of 1.
        """
        data_dict = {}

        # Load the image, pad it to a square using the image mean as fill color,
        # and run the image processor to get the high-resolution pixel values.
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        ori_w, ori_h = F.get_image_size(image)
        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.image_processor.image_mean),
        )
        pad_w, pad_h = F.get_image_size(image)
        image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        resize_h, resize_w = image_aux.shape[-2:]
        data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)

        # The low-resolution branch is a 336x336 bilinear downsample of the same image.
        image = image_aux.clone()
        image = torch.nn.functional.interpolate(
            image[None],
            size=[336, 336],
            mode="bilinear",
            align_corners=False,
        )[0]
        data_dict["pixel_values"] = image.unsqueeze(0)

        # Apply the same pad + resize transform to the boxes as was applied to the image.
        bbox = xyxy_to_xywh(bbox)
        bbox = pad_boxes(bbox, (ori_w, ori_h))
        bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w))
        data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0)

        # Build the prompt: one <objN><objfeat> pair per box, preceded by the image token.
        total_num_boxes = len(bbox)
        obj_tokens = [
            DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) for i in range(total_num_boxes)
        ]
        obj_tokens = (
            DEFAULT_OBJECT_FEATURE_TOKEN.join(obj_tokens) + DEFAULT_OBJECT_FEATURE_TOKEN
        )
        question = question.replace(DEFAULT_IMAGE_TOKEN, "")
        question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question

        inputs = ""
        inputs += self.template["INSTRUCTION"].format(input=question, round=1)

        # Tokenize, mapping <image> and <objfeat> to their sentinel ids.
        input_ids = tokenizer_image_object_token(inputs, self.tokenizer)
        data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)

        return data_dict


ChatRexProcessor.register_for_auto_class()
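
# Hypothetical usage sketch (the checkpoint path, file name, box, and question are
# placeholders, not taken from this file):
#
#   from transformers import AutoImageProcessor, AutoTokenizer
#
#   image_processor = AutoImageProcessor.from_pretrained("path/to/chatrex-checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("path/to/chatrex-checkpoint")
#   processor = ChatRexProcessor(image_processor=image_processor, tokenizer=tokenizer)
#   inputs = processor.process(
#       image="demo.jpg",
#       bbox=[[48, 70, 190, 220]],          # one box in [x_min, y_min, x_max, y_max]
#       question="What is object <obj0>?",
#   )
#   # inputs now holds "pixel_values", "pixel_values_aux", "gt_boxes" and "input_ids".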