from pathlib import Path

import PIL.Image
import torch
import torchvision.transforms.functional as FT
from torch.utils.data import Dataset
from torchvision.transforms import Compose, CenterCrop, ToTensor, Normalize, Resize
from torchvision.transforms import InterpolationMode

PROJECT_ROOT = Path(__file__).absolute().parents[1]


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def collate_fn(batch):
    """
    Discard None entries in a batch when using a torch DataLoader, so that a
    Dataset's __getitem__ can return None for unreadable images.
    :param batch: input batch, possibly containing None entries
    :return: the input batch with None entries removed, collated by default_collate
    """
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
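
# Usage sketch (``my_dataset`` is illustrative, not defined in this module):
# passing ``collate_fn`` to a DataLoader lets a Dataset return None for
# unreadable samples without crashing the default collation.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(my_dataset, batch_size=32, collate_fn=collate_fn)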


class TargetPad:
    """
    If an image's aspect ratio is above a target ratio, pad the image to match that target ratio.
    For more details see Baldrati et al. "Effective conditioned and composed image retrieval combining CLIP-based features." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).
    """

    def __init__(self, target_ratio: float, size: int):
        """
        :param target_ratio: target aspect ratio
        :param size: preprocessing output dimension
        """
        self.size = size
        self.target_ratio = target_ratio

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        w, h = image.size
        actual_ratio = max(w, h) / min(w, h)
        if actual_ratio < self.target_ratio:  # the ratio is already below the target: no padding needed
            return image
        scaled_max_wh = max(w, h) / self.target_ratio  # target length of the shorter side after padding
        hp = max(int((scaled_max_wh - w) / 2), 0)
        vp = max(int((scaled_max_wh - h) / 2), 0)
        padding = [hp, vp, hp, vp]  # left, top, right, bottom
        return FT.pad(image, padding, 0, 'constant')
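
# Worked example (a sketch; the 400x100 image and target_ratio=1.25 are
# illustrative values): the ratio 400/100 = 4.0 exceeds 1.25, so
# scaled_max_wh = 400 / 1.25 = 320, hp = 0, vp = int((320 - 100) / 2) = 110,
# and the padded image is 400x320, i.e. exactly the 1.25 target ratio.
#
#   img = PIL.Image.new("RGB", (400, 100))
#   padded = TargetPad(target_ratio=1.25, size=288)(img)
#   assert padded.size == (400, 320)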


def targetpad_transform(target_ratio: float, dim: int) -> Compose:
    """
    CLIP-like preprocessing transform applied after padding with TargetPad
    :param target_ratio: target ratio for TargetPad
    :param dim: image output dimension
    :return: CLIP-like torchvision Compose transform
    """
    return Compose([
        TargetPad(target_ratio, dim),
        Resize(dim, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(dim),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP mean/std
    ])
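
# Usage sketch (the file name and dim=224 are illustrative; 224 matches the
# CLIP ViT-B input resolution, other backbones use different sizes): the
# returned Compose maps a PIL image to a (3, dim, dim) float tensor normalized
# with the standard CLIP mean/std, ready for a CLIP visual encoder.
#
#   preprocess = targetpad_transform(target_ratio=1.25, dim=224)
#   image_tensor = preprocess(PIL.Image.open("example.jpg"))  # shape (3, 224, 224)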