# LinCIR / data_utils.py
# Author: Geonmo
# Commit: cacafc1 ("initial commit")
from pathlib import Path
import PIL
import torch
import torchvision.transforms.functional as FT
from torch.utils.data import Dataset
from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
from torchvision.transforms import InterpolationMode
# Repository root: the grandparent directory of this file, as an absolute path.
# (Every ancestor of an absolute path is already absolute, so no second .absolute() is needed.)
PROJECT_ROOT = Path(__file__).absolute().parents[1]
def _convert_image_to_rgb(image):
    """
    Return ``image.convert("RGB")``.

    Kept as a named module-level function so it can be used as a plain
    callable step inside a torchvision ``Compose`` pipeline.
    :param image: object exposing a ``convert(mode)`` method (a PIL image in this module)
    :return: whatever ``image.convert("RGB")`` returns
    """
    target_mode = "RGB"
    return image.convert(target_mode)
def collate_fn(batch):
    """
    DataLoader collate function that drops ``None`` samples before batching.

    Samples equal to ``None`` are removed. The remaining ones are passed to
    torch's ``default_collate`` in their original order.
    :param batch: iterable of samples produced by the dataset (some may be ``None``)
    :return: ``default_collate`` applied to the non-``None`` samples
    """
    kept_samples = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept_samples)
class TargetPad:
    """
    Pad an image whose aspect ratio (max side / min side) is at least ``target_ratio``
    so that it ends up approximately at ``target_ratio``.

    Images whose aspect ratio is below ``target_ratio`` are returned unchanged.
    Padding is symmetric, constant-valued and filled with 0. When
    ``target_ratio >= 1`` only the shorter side receives padding.
    For more details see Baldrati et al. 'Effective conditioned and composed image retrieval combining clip-based features.' Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: target aspect ratio, expressed as max side / min side
        :param size: preprocessing output dimension; stored on the instance but not used by ``__call__``
        """
        self.size = size
        self.target_ratio = target_ratio

    # Annotations are strings on purpose: ``import PIL`` alone does not load the
    # ``PIL.Image`` submodule, so evaluating ``PIL.Image.Image`` at definition time
    # would only work if some other import had already loaded it.
    def __call__(self, image: "PIL.Image.Image") -> "PIL.Image.Image":
        """
        Pad ``image`` towards ``self.target_ratio``.

        :param image: input image; ``image.size`` is read as ``(width, height)``
        :return: ``image`` itself if its aspect ratio is below the target,
                 otherwise a padded image
        :raises ZeroDivisionError: if the width or height is 0
        """
        w, h = image.size
        actual_ratio = max(w, h) / min(w, h)
        if actual_ratio < self.target_ratio:  # ratio below target: leave the image untouched
            return image
        # Length of the shorter side needed for max(w, h) / side == target_ratio.
        scaled_max_wh = max(w, h) / self.target_ratio
        # Per-side padding; a negative amount (side already long enough) is clamped to 0.
        hp = max(int((scaled_max_wh - w) / 2), 0)
        vp = max(int((scaled_max_wh - h) / 2), 0)
        padding = [hp, vp, hp, vp]  # torchvision order: left, top, right, bottom
        return FT.pad(image, padding, fill=0, padding_mode="constant")
def targetpad_transform(target_ratio: float, dim: int) -> "Compose":
    """
    Build a CLIP-like preprocessing pipeline that applies ``TargetPad`` first.

    Steps, in order:
      1. ``TargetPad(target_ratio, dim)``: pad images whose aspect ratio reaches ``target_ratio``;
      2. ``Resize(dim)`` with bicubic interpolation;
      3. ``CenterCrop(dim)``;
      4. conversion to RGB;
      5. ``ToTensor()``;
      6. ``Normalize`` with CLIP per-channel mean/std.
    :param target_ratio: target ratio for TargetPad
    :param dim: image output dimension
    :return: a torchvision ``Compose`` transform (not a tensor; the tensor comes from calling it on an image)
    """
    # The return annotation is a string so it is not evaluated when the function is defined.
    return Compose([
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])