import os
import random
from typing import Callable, Dict, List

import albumentations as alb
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.readers import LmdbReader
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T


class CaptioningDataset(Dataset):
    r"""
    A dataset which provides image-caption (forward and backward) pairs from
    a serialized LMDB file (COCO Captions in this codebase). This is used for
    pretraining tasks which use captions - bicaptioning, forward captioning,
    and token classification.

    This dataset also supports training on a randomly selected subset of the
    full dataset.

    Parameters
    ----------
    data_root: str, optional (default = "datasets/coco")
        Path to the dataset root directory. This must contain the serialized
        LMDB files (for COCO ``train2017`` and ``val2017`` splits).
    split: str, optional (default = "train")
        Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``.
    tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
        A tokenizer which has the mapping between word tokens and their
        integer IDs.
    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
        A list of transformations, from either `albumentations
        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`,
        to be applied on the image.
    max_caption_length: int, optional (default = 30)
        Maximum number of tokens to keep in the output caption tokens. Extra
        tokens will be trimmed from the right end of the token list.
    use_single_caption: bool, optional (default = False)
        COCO Captions provides five captions per image. If this is True, only
        one fixed caption per image is used for training (an ablation setting).
    percentage: float, optional (default = 100.0)
        Randomly sample this percentage of the full dataset for training.
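
    Examples
    --------
    A minimal usage sketch (assuming ``tokenizer`` is an already constructed
    :class:`~virtex.data.tokenizers.SentencePieceBPETokenizer`; the data root
    path is the default shown above)::

        >>> dataset = CaptioningDataset("datasets/coco", "train", tokenizer)
        >>> instance = dataset[0]
        >>> instance["image"].shape        # (3, height, width), float tensor
        >>> instance["caption_tokens"]     # 1D long tensor of token IDs
        >>> instance["noitpac_tokens"]     # same tokens, reversed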
""" | |

    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
        use_single_caption: bool = False,
        percentage: float = 100.0,
    ):
        lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb")
        self.reader = LmdbReader(lmdb_path, percentage=percentage)

        self.image_transform = image_transform
        self.caption_transform = alb.Compose(
            [
                T.NormalizeCaption(),
                T.TokenizeCaption(tokenizer),
                T.TruncateCaptionTokens(max_caption_length),
            ]
        )
        self.use_single_caption = use_single_caption
        self.padding_idx = tokenizer.token_to_id("<unk>")

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        image_id, image, captions = self.reader[idx]

        # Pick a random caption (or the first caption) and process (transform) it.
        if self.use_single_caption:
            caption = captions[0]
        else:
            caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = self.caption_transform(caption=caption)["caption"]
        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long),
            "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        # Pad `caption_tokens` and `noitpac_tokens` up to the longest caption
        # in the batch, using the padding (`<unk>`) index.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["noitpac_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "noitpac_tokens": noitpac_tokens,
            "caption_lengths": torch.stack([d["caption_lengths"] for d in data]),
        }
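

# A minimal usage sketch (not part of the original module): the dataset is
# meant to be wrapped in a ``torch.utils.data.DataLoader`` with its own
# ``collate_fn``, which pads variable-length captions with the ``<unk>``
# index so the batch tensors are rectangular. The batch size below and the
# already-constructed ``tokenizer`` are assumptions for illustration only.
#
#     from torch.utils.data import DataLoader
#
#     dataset = CaptioningDataset("datasets/coco", "train", tokenizer)
#     loader = DataLoader(
#         dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn
#     )
#     batch = next(iter(loader))
#     # batch["caption_tokens"]: (32, longest caption in batch) long tensor
#     # batch["noitpac_tokens"]: the same captions reversed, for backward captioning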