Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import json | |
import os | |
from typing import Callable, List, Tuple, Union | |
from PIL import Image | |
from torch import Tensor | |
from torch.utils.data import Dataset | |
class RetrievalTrainingDataset(Dataset): | |
""" | |
Create the training dataset for Retrieval task. | |
Args: | |
ann_file (List[str]): The paths to training annotation json files. | |
image_root (str): The path to image data directory. | |
image_transform (Callable[[Image.Image], Tensor]): Image data transform. | |
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform. | |
Dataset Outputs: | |
image (Tensor): Transformed image input tensor of shape (C, H, W). | |
caption (Tensor): Transformed text token input ids. | |
idx (int): The unique identifier for the image. | |
""" | |
def __init__( | |
self, | |
ann_file: List[str], | |
image_root: str, | |
image_transform: Callable[[Image.Image], Tensor], | |
text_transform: Callable[[Union[List[str], str]], Tensor], | |
) -> None: | |
self.ann = [] | |
for f in ann_file: | |
self.ann += json.load(open(f, "r")) | |
self.image_root = image_root | |
self.image_transform = image_transform | |
self.text_transform = text_transform | |
self.idx = {} # map str image_id from dataset to int ids | |
i = 0 | |
for ann in self.ann: | |
image_id = ann["image_id"] | |
if image_id not in self.idx.keys(): | |
self.idx[image_id] = i | |
i += 1 | |
def __len__(self) -> int: | |
return len(self.ann) | |
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]: | |
ann = self.ann[index] | |
image_path = os.path.join(self.image_root, ann["image"]) | |
image = Image.open(image_path).convert("RGB") | |
image = self.image_transform(image) | |
caption = self.text_transform(ann["caption"]) | |
return image, caption, self.idx[ann["image_id"]] | |
class ImageToTextRetrievalDataset(Dataset): | |
""" | |
Create the dataset for Image-to-Text Retrieval task. | |
Args: | |
ann_file (List[str]): The paths to annotation json files. | |
image_root (str): The path to image data directory. | |
image_transform (Callable[[Image.Image], Tensor]): Image data transform. | |
Dataset Outputs: | |
image (Tensor): Transformed image input tensor of shape (C, H, W). | |
""" | |
def __init__( | |
self, | |
ann_file: List[str], | |
image_root: str, | |
image_transform: Callable[[Image.Image], Tensor], | |
) -> None: | |
self.image_root = image_root | |
self.image_transform = image_transform | |
self.ann = [] | |
self.images = [] # paths to all images in the dataset | |
self.image_to_text = {} # map image ids to text ids for evaluation | |
for f in ann_file: | |
self.ann += json.load(open(f, "r")) | |
text_id = 0 | |
for image_id, ann in enumerate(self.ann): | |
self.images.append(ann["image"]) | |
num_text = len(ann["caption"]) | |
self.image_to_text[image_id] = list(range(text_id, text_id + num_text)) | |
text_id += num_text | |
def __len__(self) -> int: | |
return len(self.images) | |
def __getitem__(self, index: int) -> Tensor: | |
image_path = os.path.join(self.image_root, self.images[index]) | |
image = Image.open(image_path).convert("RGB") | |
image = self.image_transform(image) | |
return image | |
class TextToImageRetrievalDataset(Dataset): | |
""" | |
Create the dataset for Text-to-Image Retrieval task. | |
Args: | |
ann_file (List[str]): The paths to annotation json files. | |
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform. | |
Dataset Outputs: | |
text (Tensor): Transformed text token input ids. | |
""" | |
def __init__( | |
self, | |
ann_file: List[str], | |
text_transform: Callable[[Union[List[str], str]], Tensor], | |
) -> None: | |
self.text_transform = text_transform | |
self.ann = [] | |
self.text = [] # all text strings in the dataset | |
self.text_to_image = {} # map text ids to image ids for evaluation | |
for f in ann_file: | |
self.ann += json.load(open(f, "r")) | |
text_id = 0 | |
for image_id, ann in enumerate(self.ann): | |
for caption in ann["caption"]: | |
self.text.append(caption) | |
self.text_to_image[text_id] = image_id | |
text_id += 1 | |
def __len__(self) -> int: | |
return len(self.text) | |
def __getitem__(self, index: int) -> Tensor: | |
text = self.text_transform(self.text[index]) | |
return text | |