import os
import random
from typing import Callable, Dict, List
import albumentations as alb
import numpy as np
import torch
from torch.utils.data import Dataset
from virtex.data.readers import LmdbReader
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
class CaptioningDataset(Dataset):
    r"""
    A dataset which provides image-caption (forward and backward) pairs from
    a serialized LMDB file (COCO Captions in this codebase). This is used for
    pretraining tasks which use captions - bicaptioning, forward captioning and
    token classification.

    This dataset also supports training on a randomly selected subset of the
    full dataset.

    Parameters
    ----------
    data_root: str
        Path to the dataset root directory (e.g. ``datasets/coco``). This must
        contain the serialized LMDB files, named ``serialized_{split}.lmdb``
        (for COCO ``train2017`` and ``val2017`` splits).
    split: str
        Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``.
    tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
        A tokenizer which has the mapping between word tokens and their
        integer IDs.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A callable, typically a composition of transformations from either
        `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
        or :mod:`virtex.data.transforms`, to be applied on the image. It is
        called as ``image_transform(image=..., caption=...)`` and must return
        a dict with ``"image"`` and ``"caption"`` keys, so that paired
        transforms (such as horizontal flip) can also update the caption.
    max_caption_length: int, optional (default = 30)
        Maximum number of tokens to keep in output caption tokens. Extra tokens
        will be trimmed from the right end of the token list.
    use_single_caption: bool, optional (default = False)
        COCO Captions provides five captions per image. If this is True, only
        one fixed caption (the first one) per image is used for training (used
        for an ablation). Otherwise a caption is chosen at random each time an
        instance is accessed.
    percentage: float, optional (default = 100.0)
        Randomly sample this much percentage of the full dataset for training.
        This value is forwarded to ``LmdbReader``.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
        use_single_caption: bool = False,
        percentage: float = 100.0,
    ):
        lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb")
        self.reader = LmdbReader(lmdb_path, percentage=percentage)

        self.image_transform = image_transform
        # Text-only pipeline: normalize, tokenize to integer IDs, then
        # truncate to ``max_caption_length`` tokens.
        self.caption_transform = alb.Compose(
            [
                T.NormalizeCaption(),
                T.TokenizeCaption(tokenizer),
                T.TruncateCaptionTokens(max_caption_length),
            ]
        )
        self.use_single_caption = use_single_caption

        # Token sequences in a batch are padded with the ID of the "<unk>"
        # token. Name it explicitly rather than looking up an empty string.
        # NOTE(review): confirm "<unk>" is the intended padding token for
        # downstream models.
        self.padding_idx = tokenizer.token_to_id("<unk>")

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Read instance ``idx`` and return it as a dict of tensors.

        Keys:

        - ``image_id``: scalar ``long`` tensor.
        - ``image``: ``float`` tensor, transposed from HWC to CHW.
        - ``caption_tokens``: 1-D ``long`` tensor of caption token IDs.
        - ``noitpac_tokens``: ``caption_tokens`` reversed (backward direction).
        - ``caption_lengths``: scalar ``long`` tensor, number of caption tokens.
        """
        image_id, image, captions = self.reader[idx]

        # Pick the first caption (single-caption ablation) or a random one.
        if self.use_single_caption:
            caption = captions[0]
        else:
            caption = random.choice(captions)

        # Pass the caption into ``image_transform`` too, because a paired
        # horizontal flip must also update the caption text. The caption is
        # not tokenized/processed at this step.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]

        # HWC -> CHW.
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = self.caption_transform(caption=caption)["caption"]
        caption_tokens = torch.tensor(caption_tokens, dtype=torch.long)

        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "caption_tokens": caption_tokens,
            # ``flip`` returns a copy, so both entries own separate storage.
            "noitpac_tokens": caption_tokens.flip(0),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate instances produced by :meth:`__getitem__` into one batch.

        ``image_id``, ``image`` and ``caption_lengths`` are stacked along a
        new leading batch dimension. ``caption_tokens`` and ``noitpac_tokens``
        are padded on the right with ``self.padding_idx`` up to the longest
        caption in the batch.
        """
        # Pad `caption_tokens` and `noitpac_tokens` up to the longest caption
        # in this batch.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["noitpac_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "noitpac_tokens": noitpac_tokens,
            "caption_lengths": torch.stack(
                [d["caption_lengths"] for d in data], dim=0
            ),
        }