from typing import List

import numpy as np
import torch
from PIL import Image

from src.utils.util import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    add_image_tokens_to_prompt,
    process_images,
)


class ImageCraftProcessor:
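    """Prepare (image, prompt) pairs for the ImageCraft model.

    Images are resized, rescaled, and normalized; prompts are prefixed
    with `<image>` placeholder tokens and tokenized with a Gemma-style
    tokenizer extended with location and segmentation tokens.
    """
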
    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
        super().__init__()
        self.image_seq_length = num_image_tokens
        self.image_size = image_size
        # Tokenizer described here:
        # https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
        tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(tokens_to_add)
        # These tokens are used for object detection (bounding boxes).
        EXTRA_TOKENS = [f"<loc{i:04d}>" for i in range(1024)]
        # These tokens are used for object segmentation.
        EXTRA_TOKENS += [f"<seg{i:03d}>" for i in range(128)]
        tokenizer.add_tokens(EXTRA_TOKENS)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        # We will add the BOS and EOS tokens ourselves.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer

    def __call__(
        self,
        text: List[str],
        images: List[Image.Image],
        padding: str = "longest",
        truncation: bool = True,
    ) -> dict:
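        """Tokenize a single prompt and preprocess a single image.

        Returns a dict with `pixel_values` (a float16 tensor of shape
        [1, 3, image_size, image_size]) merged with the tokenizer
        outputs (`input_ids`, `attention_mask`).
        """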
        assert (
            len(images) == 1 and len(text) == 1
        ), f"Received {len(images)} images for {len(text)} prompts."
        pixel_values = process_images(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD,
        )
        # Convert the list of numpy arrays to a single numpy array with
        # shape [Batch_Size, Channel, Height, Width].
        pixel_values = np.stack(pixel_values, axis=0)
        # Convert the numpy array to a PyTorch tensor.
        pixel_values = torch.tensor(pixel_values, dtype=torch.float16)
        input_strings = [
            add_image_tokens_to_prompt(
                prefix_prompt=prompt,
                bos_token=self.tokenizer.bos_token,
                image_seq_length=self.image_seq_length,
                image_token=self.IMAGE_TOKEN,
            )
            for prompt in text
        ]
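        # Assumed prompt layout (per the PaliGemma tokenizer README
        # linked above): "<image>" * image_seq_length + bos_token + prompt + "\n".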
        # Note: the <image> placeholder tokens count toward max_length.
        inputs = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=padding,
            max_length=512,
            truncation=truncation,
        )
        return_data = {"pixel_values": pixel_values, **inputs}
        return return_data
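

# Minimal usage sketch, not part of the processor itself. Assumptions:
# the tokenizer comes from transformers.AutoTokenizer, the checkpoint
# name and image path below are illustrative only, and 256 image tokens
# at 224x224 match PaliGemma-style defaults.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
    processor = ImageCraftProcessor(tokenizer, num_image_tokens=256, image_size=224)
    inputs = processor(
        text=["caption en"],
        images=[Image.open("example.jpg").convert("RGB")],
    )
    print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
    print(inputs["input_ids"].shape)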