# File: albef-vqa/data/retrieval_datamodule.py
# Commit d1b8c9b ("Add source code"), ryanramos; file size 5.6 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple
import torch
from data.retrieval_dataset import (
ImageToTextRetrievalDataset,
RetrievalTrainingDataset,
TextToImageRetrievalDataset,
)
from data.transforms import (
ALBEFTextTransform,
testing_image_transform,
training_image_transform,
)
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, DistributedSampler
class RetrievalDataModule(LightningDataModule):
    """
    Lightning data module bundling the datasets and loaders for retrieval.

    The constructor eagerly builds three datasets: one training dataset from
    ``train_files`` and two evaluation datasets (image-only and text-only)
    from ``test_files``.

    Args:
        train_files (List[str]): The paths to training json files.
        test_files (List[str]): The paths to testing json files.
        image_root (str): The path to image data directory.
        batch_size (int): Batch size used by every DataLoader of this module.
        num_workers (int): Forwarded as ``num_workers`` to every DataLoader.
    """

    def __init__(
        self,
        train_files: List[str],
        test_files: List[str],
        image_root: str,
        batch_size: int,
        num_workers: int,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Training split: augmented images paired with truncated,
        # variable-length text (padding happens later in the collate fn).
        train_image_tf = training_image_transform()
        train_text_tf = ALBEFTextTransform(
            truncate=True, max_seq_len=30, add_end_token=False
        )
        self.train_dataset = RetrievalTrainingDataset(
            train_files, image_root, train_image_tf, train_text_tf
        )

        # Evaluation split, image side.
        eval_image_tf = testing_image_transform()
        self.image_dataset = ImageToTextRetrievalDataset(
            test_files, image_root, eval_image_tf
        )

        # Evaluation split, text side: the transform itself pads every
        # sequence to max_seq_len.
        eval_text_tf = ALBEFTextTransform(
            truncate=True,
            pad_to_max_seq_len=True,
            max_seq_len=30,
            add_end_token=False,
        )
        self.text_dataset = TextToImageRetrievalDataset(test_files, eval_text_tf)

    def _get_sampler(
        self,
        dataset: Dataset,
        shuffle: bool,
        is_distributed: bool,
        num_tasks: int,
        global_rank: int,
    ) -> Optional[DistributedSampler]:
        """
        Build a DistributedSampler for ``dataset`` when running distributed.

        Returns None outside distributed mode so that the caller can let the
        DataLoader pick its own sampler from the ``shuffle`` flag.
        """
        if is_distributed:
            return DistributedSampler(
                dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
            )
        return None

    def _eval_dataloader(
        self, dataset: Dataset, collate_fn, drop_last: bool
    ) -> DataLoader:
        """Build an in-order (unshuffled, sampler-free) evaluation DataLoader."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            sampler=None,
            shuffle=False,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )

    def train_dataloader(
        self,
        is_distributed: bool = False,
        num_tasks: int = 0,
        global_rank: int = 0,
        drop_last: bool = True,
    ) -> DataLoader:
        """
        Build the training DataLoader.

        DataLoader Outputs:
            images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
            text (Tensor): Tensor of shape (B, L) of text inputs.
            text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
            idx (Tensor): Tensor of shape (B) of image identifiers.
        """
        train_sampler = self._get_sampler(
            dataset=self.train_dataset,
            shuffle=True,
            is_distributed=is_distributed,
            num_tasks=num_tasks,
            global_rank=global_rank,
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            sampler=train_sampler,
            # sampler and shuffle are mutually exclusive in DataLoader:
            # shuffle only when no explicit sampler was created.
            shuffle=train_sampler is None,
            collate_fn=retrieval_train_collate_fn,
            drop_last=drop_last,
        )

    def image_dataloader(
        self,
        drop_last: bool = False,
    ) -> DataLoader:
        """
        Build the evaluation DataLoader over images (default collation).

        DataLoader Outputs:
            images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
        """
        return self._eval_dataloader(self.image_dataset, None, drop_last)

    def text_dataloader(
        self,
        drop_last: bool = False,
    ) -> DataLoader:
        """
        Build the evaluation DataLoader over text.

        DataLoader Outputs:
            text (Tensor): Tensor of shape (B, L) of text inputs.
            text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
        """
        return self._eval_dataloader(self.text_dataset, text_collate_fn, drop_last)
def retrieval_train_collate_fn(
    batch: List[Tuple[Tensor, Tensor, int]]
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Collate (image, text, idx) training samples into batched tensors.

    Args:
        batch (List[Tuple[Tensor, Tensor, int]]): Samples, each made of an
            image tensor, a 1-D text token tensor, and an integer identifier.
            All image tensors must share one shape so they can be stacked.

    Returns:
        images (Tensor): The images stacked along a new leading batch dim.
        text (Tensor): The text tensors right-padded with 0 to the longest
            sequence in the batch, batch dimension first.
        text_atts (Tensor): Long mask with 1 where ``text != 0`` and 0
            elsewhere, i.e. token id 0 is treated as padding.
        idx (Tensor): Long tensor holding the identifiers of the samples.
    """
    image_list = [sample[0] for sample in batch]
    text_list = [sample[1] for sample in batch]
    idx_list = [sample[2] for sample in batch]

    images = torch.stack(image_list, dim=0)
    text = pad_sequence(text_list, batch_first=True)
    text_atts = (text != 0).type(torch.long)
    # Build the id tensor directly as int64. Going through the float32
    # ``Tensor(...)`` constructor first would round identifiers above 2**24.
    idx = torch.tensor(idx_list, dtype=torch.long)
    return (
        images,
        text,
        text_atts,
        idx,
    )
def text_collate_fn(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
    """
    Collate text token tensors into a padded batch and its attention mask.

    Args:
        batch (List[Tensor]): Text token tensors, one per sample.

    Returns:
        text (Tensor): The inputs right-padded with 0 to the longest
            sequence in the batch, batch dimension first.
        text_atts (Tensor): Long mask that is 1 where ``text`` is non-zero
            and 0 elsewhere.
    """
    padded = pad_sequence(batch, batch_first=True)
    # Token id 0 is what pad_sequence fills with, so non-zero marks real tokens.
    attention_mask = padded.ne(0).long()
    return padded, attention_mask