# Source: albef-vqa/data/retrieval_dataset.py
# Uploaded by ryanramos in commit d1b8c9b ("Add source code"); file size 4.85 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from typing import Callable, List, Tuple, Union
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
class RetrievalTrainingDataset(Dataset):
    """
    Create the training dataset for Retrieval task.

    Each annotation entry is read for the keys ``"image"`` (joined onto
    ``image_root``), ``"caption"`` and ``"image_id"``. The image ids found in
    the annotations are remapped to consecutive integers, numbered in order of
    first appearance across all annotation files.

    Args:
        ann_file (List[str]): The paths to training annotation json files.
        image_root (str): The path to image data directory.
        image_transform (Callable[[Image.Image], Tensor]): Image data transform.
        text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.

    Dataset Outputs:
        image (Tensor): Transformed image input tensor of shape (C, H, W).
        caption (Tensor): Transformed text token input ids.
        idx (int): The unique identifier for the image.
    """

    def __init__(
        self,
        ann_file: List[str],
        image_root: str,
        image_transform: Callable[[Image.Image], Tensor],
        text_transform: Callable[[Union[List[str], str]], Tensor],
    ) -> None:
        # Concatenate the entries of every annotation file, in the given order.
        self.ann = []
        for path in ann_file:
            # The context manager closes each file as soon as it is parsed
            # instead of leaving the handle to the garbage collector.
            with open(path, "r") as f:
                self.ann += json.load(f)

        self.image_root = image_root
        self.image_transform = image_transform
        self.text_transform = text_transform

        self.idx = {}  # map str image_id from dataset to int ids
        for ann in self.ann:
            image_id = ann["image_id"]
            if image_id not in self.idx:
                # The current size of the mapping is the next unused int id.
                self.idx[image_id] = len(self.idx)

    def __len__(self) -> int:
        """Return the number of annotation entries (image-caption pairs)."""
        return len(self.ann)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]:
        """Return the transformed image, transformed caption and int image id."""
        ann = self.ann[index]
        image_path = os.path.join(self.image_root, ann["image"])
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed before the transform runs.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.image_transform(image)
        caption = self.text_transform(ann["caption"])
        return image, caption, self.idx[ann["image_id"]]
class ImageToTextRetrievalDataset(Dataset):
    """
    Create the dataset for Image-to-Text Retrieval task.

    Image ids are the positions of the annotation entries in the concatenated
    annotation list. Text ids are assigned consecutively across entries, one
    per element of each entry's ``"caption"`` value, and recorded in
    ``image_to_text``.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        image_root (str): The path to image data directory.
        image_transform (Callable[[Image.Image], Tensor]): Image data transform.

    Dataset Outputs:
        image (Tensor): Transformed image input tensor of shape (C, H, W).
    """

    def __init__(
        self,
        ann_file: List[str],
        image_root: str,
        image_transform: Callable[[Image.Image], Tensor],
    ) -> None:
        self.image_root = image_root
        self.image_transform = image_transform

        self.ann = []
        self.images = []  # paths to all images in the dataset
        self.image_to_text = {}  # map image ids to text ids for evaluation

        for path in ann_file:
            # The context manager closes each file as soon as it is parsed
            # instead of leaving the handle to the garbage collector.
            with open(path, "r") as f:
                self.ann += json.load(f)

        text_id = 0  # first text id belonging to the current image
        for image_id, ann in enumerate(self.ann):
            self.images.append(ann["image"])
            # NOTE(review): "caption" is presumably a list of caption strings
            # here; its length decides how many text ids the image receives.
            num_text = len(ann["caption"])
            self.image_to_text[image_id] = list(range(text_id, text_id + num_text))
            text_id += num_text

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, index: int) -> Tensor:
        """Load the image at ``index``, convert it to RGB and transform it."""
        image_path = os.path.join(self.image_root, self.images[index])
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed before the transform runs.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.image_transform(image)
        return image
class TextToImageRetrievalDataset(Dataset):
    """
    Create the dataset for Text-to-Image Retrieval task.

    Every element of each annotation entry's ``"caption"`` value becomes one
    dataset item. Text ids are the positions in the flattened ``text`` list,
    and image ids are the positions of the annotation entries in the
    concatenated annotation list; ``text_to_image`` links the two.

    Args:
        ann_file (List[str]): The paths to annotation json files.
        text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.

    Dataset Outputs:
        text (Tensor): Transformed text token input ids.
    """

    def __init__(
        self,
        ann_file: List[str],
        text_transform: Callable[[Union[List[str], str]], Tensor],
    ) -> None:
        self.text_transform = text_transform

        self.ann = []
        self.text = []  # all text strings in the dataset
        self.text_to_image = {}  # map text ids to image ids for evaluation

        for path in ann_file:
            # The context manager closes each file as soon as it is parsed
            # instead of leaving the handle to the garbage collector.
            with open(path, "r") as f:
                self.ann += json.load(f)

        for image_id, ann in enumerate(self.ann):
            for caption in ann["caption"]:
                # The text id is the caption's index in self.text, i.e. the
                # list length just before the caption is appended.
                self.text_to_image[len(self.text)] = image_id
                self.text.append(caption)

    def __len__(self) -> int:
        """Return the total number of captions in the dataset."""
        return len(self.text)

    def __getitem__(self, index: int) -> Tensor:
        """Return the transformed caption stored at ``index``."""
        text = self.text_transform(self.text[index])
        return text