from .image_classification import CIFAR10DataModule

from argparse import ArgumentParser
from functools import partial
from torch import LongTensor
from torch.utils.data import default_collate, random_split, Sampler
from torchvision import transforms
from torchvision.datasets import VisionDataset
from typing import Callable, Iterator, Optional
import itertools
import torch


class CIFAR10QADataModule(CIFAR10DataModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Visual QA")
        parser.add_argument(
            "--class_idx",
            type=int,
            default=3,
            help="The class (index) to count.",
        )
        parser.add_argument(
            "--grid_size",
            type=int,
            default=3,
            help="The number of images per row in the grid.",
        )
        return parent_parser
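
    # For illustration only: a hypothetical train.py that routes its parser
    # through add_model_specific_args would then accept, e.g.:
    #   python train.py --class_idx 3 --grid_size 3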

    def __init__(
        self,
        class_idx: int,
        grid_size: int = 3,
        feature_extractor: Optional[Callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """A datamodule for a modified CIFAR10 dataset used for Question Answering.

        More specifically, the task is to count the number of images of a certain
        class in a grid.

        Args:
            class_idx (int): the class (index) to count
            grid_size (int): the number of images per row in the grid
            feature_extractor (Callable): a callable feature extractor instance
            data_dir (str): the directory to store the dataset
            batch_size (int): the batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add rotation augmentation
            add_blur (bool): whether to add blur augmentation
            num_workers (int): the number of workers to use for data loading
        """
        super().__init__(
            feature_extractor,
            data_dir,
            (grid_size**2) * batch_size,
            add_noise,
            add_rotation,
            add_blur,
            num_workers,
        )
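        # Note: the parent receives (grid_size**2) * batch_size as its batch
        # size, so each DataLoader batch holds enough images for `batch_size`
        # grids after collation.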
        # Store hyperparameters
        self.class_idx = class_idx
        self.grid_size = grid_size

        # Save the existing transformations to be applied after creating the grid
        self.post_transform = self.transform
        # Set the pre-batch transformation to be the conversion from PIL to tensor
        self.transform = transforms.PILToTensor()

        # Specify the custom collate function and samplers
        self.collate_fn = self.custom_collate_fn
        self.shuffled_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=True,
        )
        self.sequential_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=False,
        )

    def custom_collate_fn(self, batch):
        # Split the batch into groups of grid_size**2
        idx = range(len(batch))
        grids = zip(*(iter(idx),) * (self.grid_size**2))
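        # The line above is the "grouper" idiom: replicating a single iterator
        # grid_size**2 times makes zip draw consecutive indices into
        # non-overlapping chunks, e.g. for grid_size=2 it turns range(8)
        # into (0, 1, 2, 3) and (4, 5, 6, 7).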

        new_batch = []
        for grid in grids:
            # Create a grid of images by stacking rows of grid_size images each
            img = torch.hstack(
                [
                    torch.dstack(
                        [batch[i][0] for i in grid[row : row + self.grid_size]]
                    )
                    for row in range(0, self.grid_size**2, self.grid_size)
                ]
            )
            # Apply the post transformations to the grid
            img = self.post_transform(img)
            # Define the target as the number of images that have the class_idx
            targets = [batch[i][1] for i in grid]
            target = targets.count(self.class_idx)
            # Append grid and target to the batch
            new_batch += [(img, target)]

        return default_collate(new_batch)
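

# A minimal usage sketch (illustrative, not part of the module). It assumes
# the parent CIFAR10DataModule implements the standard PyTorch Lightning
# hooks prepare_data/setup/train_dataloader:
#
#   dm = CIFAR10QADataModule(class_idx=3, grid_size=3, batch_size=8)
#   dm.prepare_data()
#   dm.setup("fit")
#   grids, counts = next(iter(dm.train_dataloader()))
#   # grids: a batch of 8 image grids; counts: the number of class_idx
#   # images in each grid, an integer in [0, grid_size**2]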


class ToyQADataModule(CIFAR10QADataModule):
    """A datamodule for the toy dataset as described in the paper."""

    def prepare_data(self):
        # No need to download anything for the toy task
        pass

    def setup(self, stage: Optional[str] = None):
        img_size = 16
        samples = []

        # Generate 6000 samples based on 6 different colors
        for r, g, b in itertools.product((0, 1), (0, 1), (0, 1)):
            if r == g == b:
                # We do not want black/white patches
                continue

            for _ in range(1000):
                patch = torch.vstack(
                    [
                        r * torch.ones(1, img_size, img_size),
                        g * torch.ones(1, img_size, img_size),
                        b * torch.ones(1, img_size, img_size),
                    ]
                )
                # Assign a unique id to each color
                target = int(f"{r}{g}{b}", 2) - 1
                # Append the patch and target to the samples
                samples += [(patch, target)]

        # Split the data into 90% train, 5% validation and 5% test
        train_size = int(len(samples) * 0.9)
        val_size = (len(samples) - train_size) // 2
        test_size = len(samples) - train_size - val_size

        self.train_data, self.val_data, self.test_data = random_split(
            samples,
            [
                train_size,
                val_size,
                test_size,
            ],
        )
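

# For reference: the binary-to-id mapping in ToyQADataModule.setup yields six
# classes, e.g. (r, g, b) = (0, 0, 1) -> int("001", 2) - 1 = 0 (blue) and
# (1, 1, 0) -> int("110", 2) - 1 = 5 (yellow); black (000) and white (111)
# are skipped.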


class FairGridSampler(Sampler[int]):
    def __init__(
        self,
        dataset: VisionDataset,
        class_idx: int,
        grid_size: int,
        shuffle: bool = False,
    ):
        """A sampler that returns grids of images from the dataset, with a uniformly
        random number of occurrences of a specific class of interest.

        Args:
            dataset (VisionDataset): the dataset to sample from
            class_idx (int): the class (index) to treat as the class of interest
            grid_size (int): the number of images per row in the grid
            shuffle (bool): whether to shuffle the dataset before sampling
        """
        super().__init__(dataset)

        # Save the hyperparameters
        self.dataset = dataset
        self.grid_size = grid_size
        self.n_images = grid_size**2

        # Get the indices of the class of interest
        self.class_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] == class_idx]
        )
        # Get the indices of all other classes
        self.other_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] != class_idx]
        )
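        # Note: enumerating the dataset loads every sample once up front; this
        # is cheap for CIFAR10-scale data but would be costly for datasets
        # with an expensive __getitem__.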

        # Fix the seed if shuffle is False
        self.seed = None if shuffle else self._get_seed()

    @staticmethod
    def _get_seed() -> int:
        """Utility function for generating a random seed."""
        return int(torch.empty((), dtype=torch.int64).random_().item())

    def __iter__(self) -> Iterator[int]:
        # Create a torch Generator object
        seed = self.seed if self.seed is not None else self._get_seed()
        gen = torch.Generator()
        gen.manual_seed(seed)

        # Sample the batches
        for _ in range(len(self.dataset) // self.n_images):
            # Pick the number of instances for the class of interest
            n_samples = torch.randint(self.n_images + 1, (), generator=gen).item()

            # Sample the indices from the class of interest
            idx_from_class = torch.randperm(
                len(self.class_indices),
                generator=gen,
            )[:n_samples]
            # Sample the indices from the other classes
            idx_from_other = torch.randperm(
                len(self.other_indices),
                generator=gen,
            )[: self.n_images - n_samples]

            # Concatenate the corresponding lists of patches to form a grid
            grid = (
                self.class_indices[idx_from_class].tolist()
                + self.other_indices[idx_from_other].tolist()
            )
            # Shuffle the order of the patches within the grid, using the same
            # generator so that a fixed seed stays fully reproducible
            perm = torch.randperm(len(grid), generator=gen).tolist()
            yield from (grid[i] for i in perm)

    def __len__(self) -> int:
        # Only full grids are yielded, so the remainder of the dataset is dropped
        return (len(self.dataset) // self.n_images) * self.n_images
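

# A minimal usage sketch of the sampler on its own (illustrative, not part of
# the module). It assumes a dataset whose items are (image, label) pairs; the
# DataLoader batch size must be a multiple of grid_size**2 so that whole
# grids are drawn:
#
#   from torch.utils.data import DataLoader
#
#   sampler = FairGridSampler(dataset, class_idx=3, grid_size=3, shuffle=True)
#   loader = DataLoader(dataset, batch_size=9, sampler=sampler)
#
# With CIFAR10QADataModule, this wiring happens automatically via
# self.shuffled_sampler / self.sequential_sampler and custom_collate_fn.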