""" |
|
Processor class for InfiMMHD. |
|
""" |
|
|
|
import random |
|
from typing import List, Optional, Tuple, Union |
|
import torch |
|
import torchvision.transforms.functional as F |
|
from PIL import Image |
|
from torchvision.transforms import ( |
|
CenterCrop, |
|
Compose, |
|
InterpolationMode, |
|
Normalize, |
|
Resize, |
|
ToTensor, |
|
) |
|
|
|
from transformers import AutoTokenizer |
|
from transformers.image_processing_utils import ImageProcessingMixin |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
|
IMAGE_TOKEN = "<image>"
END_OF_CHUNK_TOKEN = "<|endofchunk|>"
PAD_TOKEN = "<PAD>"

# Per-channel normalization statistics of the OpenAI CLIP training data, also
# used by the EVA-CLIP image encoder.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


def _convert_to_rgb(image):
    return image.convert("RGB")
|
|
class ResizeKeepRatio:
    """Resize an image while keeping its aspect ratio.

    Copied from `timm`.
    """

    def __init__(
        self,
        size,
        longest=0.0,
        interpolation=InterpolationMode.BICUBIC,
        random_scale_prob=0.0,
        random_scale_range=(0.85, 1.05),
        random_aspect_prob=0.0,
        random_aspect_range=(0.9, 1.11),
    ):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation = interpolation
        self.longest = float(longest)
        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range
|
    @staticmethod
    def get_params(
        img,
        target_size,
        longest,
        random_scale_prob=0.0,
        random_scale_range=(0.85, 1.05),
        random_aspect_prob=0.0,
        random_aspect_range=(0.9, 1.11),
    ):
        """Compute the output (h, w) for resizing `img` toward `target_size`."""
        source_size = img.size[::-1]  # PIL size is (w, h); convert to (h, w)
        h, w = source_size
        target_h, target_w = target_size
        ratio_h = h / target_h
        ratio_w = w / target_w
        # longest=0.0 matches the shorter edge to the target; longest=1.0 the longer edge.
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
            1.0 - longest
        )
        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1.0, 1.0)
        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            aspect_factor = random.uniform(
                random_aspect_range[0], random_aspect_range[1]
            )
            ratio_factor = (
                ratio_factor[0] / aspect_factor,
                ratio_factor[1] * aspect_factor,
            )
        size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
        return size
|
    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be resized.

        Returns:
            PIL Image: Resized image.
        """
        size = self.get_params(
            img,
            self.size,
            self.longest,
            self.random_scale_prob,
            self.random_scale_range,
            self.random_aspect_prob,
            self.random_aspect_range,
        )
        img = F.resize(img, size, self.interpolation)
        return img
|
    def __repr__(self):
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", longest={self.longest:.3f})"
        return format_string
|
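# Example (illustrative): with the default longest=0.0, ResizeKeepRatio matches
# the shorter edge to the target. For an 800x600 image and size=(448, 448),
# get_params returns [448, 597] (h, w), so __call__ yields a 597x448 image.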
|
|
def image_transform(
    image_size: Union[int, Tuple[int, int]],
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
    resize_mode: Optional[str] = None,
    interpolation: Optional[str] = None,
):
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    interpolation = interpolation or "bicubic"
    assert interpolation in ["bicubic", "bilinear", "random"]
    # "random" is accepted for API compatibility but falls back to bicubic here.
    interpolation_mode = (
        InterpolationMode.BILINEAR
        if interpolation == "bilinear"
        else InterpolationMode.BICUBIC
    )

    resize_mode = resize_mode or "shortest"
    assert resize_mode in ("shortest", "longest", "squash")

    normalize = Normalize(mean=mean, std=std)

    # Only "shortest" is implemented: resize the shorter edge to the target, then
    # center-crop to exactly `image_size` so every output has the same shape.
    assert resize_mode == "shortest"
    if not isinstance(image_size, (tuple, list)):
        image_size = (image_size, image_size)
    if image_size[0] == image_size[1]:
        # Square target: torchvision's scalar-size Resize already matches the shorter edge.
        transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
    else:
        # Non-square target: resize the shorter edge against the target dims.
        transforms = [ResizeKeepRatio(image_size)]
    transforms += [CenterCrop(image_size)]

    transforms.extend(
        [
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ]
    )
    return Compose(transforms)
|
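# Example (illustrative): the default InfiMM-HD transform maps any PIL image to
# a normalized (3, 448, 448) float tensor:
#
#     transform = image_transform(image_size=448)
#     pixels = transform(Image.new("RGB", (800, 600)))  # torch.Size([3, 448, 448])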
|
|
def get_target_size(width, height, max_image_size, min_image_size):
    """Round each side to a multiple of `min_image_size`, clamped to the range
    [`min_image_size`, `max_image_size`]."""
    target_width = 0
    target_height = 0
    if width < min_image_size:
        target_width = min_image_size
    elif width > max_image_size:
        target_width = max_image_size

    if height < min_image_size:
        target_height = min_image_size
    elif height > max_image_size:
        target_height = max_image_size

    # Sides within range are rounded to the nearest multiple of min_image_size.
    if target_width == 0:
        ratio = ((width - min_image_size) + int(0.5 * min_image_size)) // min_image_size
        target_width = ratio * min_image_size + min_image_size

    if target_height == 0:
        ratio = ((height - min_image_size) + int(0.5 * min_image_size)) // min_image_size
        target_height = ratio * min_image_size + min_image_size

    return target_width, target_height
|
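# Worked example (illustrative): with min_image_size=448 and max_image_size=1344,
# get_target_size(800, 600, 1344, 448) == (896, 448): 800 rounds to the nearest
# multiple 2 * 448 = 896, while 600 rounds down to 448.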
|
|
class EVAClipImageProcessor(ImageProcessingMixin):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.image_processor = image_transform(image_size=448)
        self.img_size = 448
|
    def _prepare_images(
        self, batch: List[List[Image.Image]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert images to tensors, reshape them, and stack them.
        Args:
            batch: A list of lists of images.
        Returns:
            batch_images: preprocessed images (tensor) or None
                shape (B, T_img, F, C, H, W); None if no images in batch
            image_mask: mask of valid image slots, shape (B, T_img)
            sub_image_shape: tiles per axis as (width, height), shape (B, 2)
        """
        # Pick a padded target size per example (from its first image) and derive
        # how many img_size x img_size tiles that size yields.
        target_image_num = []
        target_shape = []
        for x in batch:
            width, height = x[0].size
            tar_wid, tar_hei = get_target_size(width, height, 1344, self.img_size)
            target_shape.append((tar_wid, tar_hei))
            target_image_num.append(int(tar_wid / self.img_size * tar_hei / self.img_size))

        images_per_example = max(target_image_num)
        batch_images = None
        image_mask = None
        sub_image_shape = None
        for iexample, example in enumerate(batch):
            # NOTE: the slot layout assumes one image per example; a second image
            # would overwrite the slots of the first.
            for img in example:
                img_ori = img
                tar_wid, tar_hei = target_shape[iexample]
                img_new = img.resize((tar_wid, tar_hei), Image.BILINEAR)
                # Slot 0 holds the full image; the rest are non-overlapping tiles.
                sub_images = [img_ori]

                for y in range(0, tar_hei, self.img_size):
                    for x in range(0, tar_wid, self.img_size):
                        sub_img = img_new.crop((x, y, x + self.img_size, y + self.img_size))
                        sub_images.append(sub_img)

                for iimage, image in enumerate(sub_images):
                    preprocessed = self.image_processor(image)
                    if batch_images is None:
                        batch_images = torch.zeros(
                            (len(batch), images_per_example + 1, 1) + preprocessed.shape,
                            dtype=preprocessed.dtype,
                        )
                    batch_images[iexample, iimage, 0] = preprocessed
                    if not torch.is_tensor(image_mask):
                        image_mask = torch.zeros(
                            (len(batch), images_per_example + 1), dtype=preprocessed.dtype
                        )
                    image_mask[iexample, : target_image_num[iexample] + 1] = 1.0
                    if not torch.is_tensor(sub_image_shape):
                        sub_image_shape = torch.zeros((len(batch), 2), dtype=preprocessed.dtype)
                    sub_image_shape[iexample, 0] = tar_wid / self.img_size
                    sub_image_shape[iexample, 1] = tar_hei / self.img_size

        return batch_images, image_mask, sub_image_shape
|
    def preprocess(self, imgpaths=None):
        if imgpaths is None or len(imgpaths) == 0:
            # No images given: fall back to a single black placeholder image.
            images = [Image.new("RGB", (224, 224), color="black")]
        else:
            images = [Image.open(fp) for fp in imgpaths]
        return self._prepare_images([images])
|
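# Example shapes (illustrative; "x.jpg" is a placeholder path): for a single
# 800x600 image, get_target_size yields (896, 448), i.e. 2 x 1 tiles, so
# EVAClipImageProcessor().preprocess(["x.jpg"]) returns batch_images of shape
# (1, 3, 1, 3, 448, 448) (full image + 2 tiles), image_mask of shape (1, 3),
# and sub_image_shape of shape (1, 2) holding (2.0, 1.0).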
|
|
class InfiMMHDProcessor(ProcessorMixin):
    r"""
    Constructs an InfiMM-HD processor which wraps a tokenizer and an image processor
    into a single processor.

    Args:
        tokenizer (`LlamaTokenizer`, *optional*):
            An instance of [`LlamaTokenizer`]. If not provided, one is loaded from
            the `infimm-hd` checkpoint. The image processor
            ([`EVAClipImageProcessor`], 448x448 sub-images) is created internally.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "LlamaTokenizer"

    def __init__(self, tokenizer=None, **kwargs):
        self.image_processor = EVAClipImageProcessor()
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("infimm-hd", verbose=False)

        super().__init__(tokenizer)
|
    def _prepare_text(
        self,
        batch: List[str],
        padding="longest",
        truncation=True,
        max_length=2048,
    ):
        """
        Tokenize the text and stack the results.
        Args:
            batch: A list of strings.
        Returns:
            input_ids (tensor)
                shape (B, T_txt)
            attention_mask (tensor)
                shape (B, T_txt)
        """
        batch = [b.strip() for b in batch]
        encodings = self.tokenizer(
            batch,
            padding=padding,
            truncation=truncation,
            return_tensors="pt",
            max_length=max_length,
        )
        input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]

        return input_ids, attention_mask
|
    def __call__(
        self,
        prompts,
    ) -> BatchEncoding:
        """Takes batched or non-batched prompts made of text and images, converts
        them into the prompt format the model was trained on, and prepares the
        image pixel values for the model to process.
        """
        image_paths = self._extract_image_paths(prompts)
        images, image_mask, sub_image_shape = self.image_processor.preprocess(image_paths)
        prompts = self._replace_with_media_tokens(prompts)
        final_prompt = self.apply_template(prompts)

        input_ids, attention_mask = self._prepare_text([final_prompt])
        return BatchEncoding(
            data={
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "batch_images": images,
                "image_mask": image_mask,
                "subimage_shape": sub_image_shape,
            }
        )
|
    def _extract_image_paths(self, prompts):
        image_paths = []
        for turn in prompts:
            if turn["role"] != "user":
                continue
            for piece in turn["content"]:
                if isinstance(piece, dict):
                    image_paths.append(piece["image"])
        return image_paths
|
    def _replace_with_media_tokens(self, prompts):
        new_prompts = []
        is_first_img = True
        for turn in prompts:
            if turn["role"] != "user":
                new_prompts.append(turn)
                continue
            new_content = []
            for piece in turn["content"]:
                if isinstance(piece, dict):
                    new_content.append(
                        f"{IMAGE_TOKEN}" if is_first_img
                        else f"{END_OF_CHUNK_TOKEN}{IMAGE_TOKEN}"
                    )
                    is_first_img = False
                else:
                    new_content.append(piece)
            new_prompts.append({"role": "user", "content": "".join(new_content)})
        return new_prompts
|
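    # Example (illustrative): a user turn whose content is
    # [{"image": "a.jpg"}, "What is this?"] is flattened to the single string
    # "<image>What is this?"; any later image in the conversation is rendered
    # as "<|endofchunk|><image>".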
|
|
    def apply_template(self, messages, task="generation"):
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=(task == "generation"),
        )
        return prompt
|
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of
        this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's
        [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this
        method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
|
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        # EVAClipImageProcessor does not declare model_input_names, so fall back
        # to an empty list instead of raising AttributeError.
        image_processor_input_names = getattr(self.image_processor, "model_input_names", [])
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
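# Minimal usage sketch (illustrative; assumes a local "infimm-hd" checkpoint whose
# tokenizer ships a chat template, and "example.jpg" is a placeholder path):
#
#     processor = InfiMMHDProcessor()
#     inputs = processor(
#         [{"role": "user", "content": [{"image": "example.jpg"}, "Describe it."]}]
#     )
#     # inputs carries input_ids, attention_mask, batch_images, image_mask,
#     # and subimage_shape, ready to feed the model.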