""" Processor class for Molmo. """ from typing import Optional import PIL from PIL import Image try: from typing import Unpack except ImportError: from typing_extensions import Unpack import re from typing import List, Optional, Union import numpy as np import torch import torchvision.transforms.functional as F from transformers import AutoTokenizer from transformers.image_utils import ImageInput from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, TextKwargs) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging logger = logging.get_logger(__name__) IGNORE_INDEX = -100 DEFAULT_PAD_TOKEN_INDEX = 0 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = "" # For Objects DEFAULT_OBJECT_TOKEN = ">" DEFAULT_OBJECT_FEATURE_TOKEN = "" DEFAULT_OBJECT_INDEX = -300 # For Grounding DEFAULT_GROUNDING_START = "" DEFAULT_GROUNDING_END = "" DEFAULT_GROUNDING_OBJECTS_START = "" DEFAULT_GROUNDING_OBJECTS_END = "" def xyxy_to_xywh(boxes): """ Convert boxes from xywh to xyxy format. Parameters: boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes. Each box is represented as [x, y, x, y]. Returns: numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, w, h]. """ boxes = np.array(boxes) x_min, y_min, x_max, y_max = ( boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3], ) w = x_max - x_min h = y_max - y_min return np.stack([x_min, y_min, w, h], axis=1) def xywh_to_xyxy(boxes): """ Convert boxes from xywh to xyxy format. Parameters: boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes. Each box is represented as [x, y, width, height]. Returns: numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, x_max, y_max]. """ boxes = np.array(boxes) x, y, width, height = ( boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3], ) x_max = x + width y_max = y + height return np.stack([x, y, x_max, y_max], axis=1) def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result def pad_boxes(gt_boxes, old_size): old_w, old_h = old_size gt_boxes = np.array(gt_boxes).astype(np.float32) # Calculate the padding added if old_w > old_h: pad_top = (old_w - old_h) // 2 pad_bottom = old_w - old_h - pad_top pad_left, pad_right = 0, 0 else: pad_left = (old_h - old_w) // 2 pad_right = old_h - old_w - pad_left pad_top, pad_bottom = 0, 0 # Adjust the boxes for padding gt_boxes[:, 0] += pad_left # x gt_boxes[:, 1] += pad_top # y return gt_boxes def resize_boxes(gt_boxes, old_size, new_size): old_w, old_h = old_size new_h, new_w = new_size gt_boxes = np.array(gt_boxes).astype(np.float32) # Calculate scale factors scale_x = new_w / max(old_w, old_h) scale_y = new_h / max(old_w, old_h) # Resize the boxes gt_boxes[:, 0] *= scale_x # x gt_boxes[:, 1] *= scale_y # y gt_boxes[:, 2] *= scale_x # w gt_boxes[:, 3] *= scale_y # h return gt_boxes def split_special_strings(input_string: str, special_strings: list[str] = None): """Split the input string into a list of strings, keeping the special strings. Args: input_string (str): The input string to split. Example: input_string = "\n\n I am happy today." output = ['', '\n', '', '', '', '\n I am happy today.'] Returns: list: A list of strings, with the special strings separated from the rest of the input string. """ # Create a regex pattern to match the special strings pattern = "|".join(map(re.escape, special_strings)) # Split the input string using the pattern, keeping the special strings in the result split_list = re.split(f"({pattern})", input_string) # Remove empty strings from the list split_list = [s for s in split_list if s] return split_list def tokenizer_image_object_token(prompt, tokenizer): bos_token_id = tokenizer.bos_token_id split_tokens = [DEFAULT_IMAGE_TOKEN, DEFAULT_OBJECT_FEATURE_TOKEN] chunks = split_special_strings(prompt, split_tokens) input_encode = [bos_token_id] for chunk in chunks: if chunk == DEFAULT_IMAGE_TOKEN: input_encode.append(IMAGE_TOKEN_INDEX) elif chunk == DEFAULT_OBJECT_FEATURE_TOKEN: input_encode.append(DEFAULT_OBJECT_INDEX) else: input_encode.extend(tokenizer.encode(chunk, add_special_tokens=False)) return input_encode class ChatRexProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" def __init__(self, image_processor = None, tokenizer : AutoTokenizer = None, **kwargs): # self.image_processor = image_processor # self.tokenizer = tokenizer super().__init__(image_processor, tokenizer) self._special_tokens = None self.template = dict( SYSTEM=('A chat between a curious user and an artificial ' 'intelligence assistant. The assistant gives ' 'helpful, detailed, and polite answers to the ' 'user\'s questions. {system}\n '), INSTRUCTION=('USER: {input} ASSISTANT:'), SEP='\n') def process( self, image: Union[str, Image.Image], bbox: List[List[int]], question: str, ): """Prepare input data for inference. Args: image (Union[str, Image.Image]): The image to process. bbox (List[List[int]]): A list of bounding boxes for the image. Each bounding box should be in order of [x, y, x , y]. question (str): The question to ask about the image. """ data_dict = {} # step1 load image if type(image) == str: image = Image.open(image).convert("RGB") ori_w, ori_h = F.get_image_size(image) image = expand2square( image, tuple(int(x * 255) for x in self.image_processor.image_mean), ) pad_w, pad_h = F.get_image_size(image) image_aux = self.image_processor.preprocess(image, return_tensors="pt")[ "pixel_values" ][0] resize_h, resize_w = image_aux.shape[-2:] data_dict["pixel_values_aux"] = image_aux.unsqueeze(0) image = image_aux.clone() image = torch.nn.functional.interpolate( image[None], size=[336, 336], mode="bilinear", align_corners=False, )[0] data_dict["pixel_values"] = image.unsqueeze(0) # step2 load boxes bbox= xyxy_to_xywh(bbox) bbox = pad_boxes(bbox, (ori_w, ori_h)) bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w)) data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0) # step3 prepare question total_num_boxes = len(bbox) obj_tokens = [ DEFAULT_OBJECT_TOKEN.replace("", str(i)) for i in range(total_num_boxes) ] obj_tokens = ( DEFAULT_OBJECT_FEATURE_TOKEN.join(obj_tokens) + DEFAULT_OBJECT_FEATURE_TOKEN ) question = question.replace(DEFAULT_IMAGE_TOKEN, "") question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question inputs = "" inputs += self.template["INSTRUCTION"].format(input=question, round=1) # step4 tokenize question input_ids = tokenizer_image_object_token(inputs, self.tokenizer) data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0) return data_dict ChatRexProcessor.register_for_auto_class()