from .image_classification import CIFAR10DataModule
from argparse import ArgumentParser
from functools import partial
from torch import LongTensor
from torch.utils.data import default_collate, random_split, Sampler
from torchvision import transforms
from torchvision.datasets import VisionDataset
from typing import Callable, Iterator, Optional

import itertools
import torch


class CIFAR10QADataModule(CIFAR10DataModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Visual QA")
        parser.add_argument(
            "--class_idx",
            type=int,
            default=3,
            help="The class (index) to count.",
        )
        parser.add_argument(
            "--grid_size",
            type=int,
            default=3,
            help="The number of images per row in the grid.",
        )
        return parent_parser

    def __init__(
        self,
        class_idx: int,
        grid_size: int = 3,
        feature_extractor: Optional[Callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """A datamodule for a modified CIFAR10 dataset that is used for Question Answering.
        More specifically, the task is to count the number of images of a certain class in a grid.

        Args:
            class_idx (int): the class (index) to count
            grid_size (int): the number of images per row in the grid
            feature_extractor (callable): a callable feature extractor instance
            data_dir (str): the directory to store the dataset
            batch_size (int): the batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add rotation augmentation
            add_blur (bool): whether to add blur augmentation
            num_workers (int): the number of workers to use for data loading
        """
        super().__init__(
            feature_extractor,
            data_dir,
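            # The underlying loader must supply grid_size**2 raw images per
            # grid, so its batch size is scaled accordingly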
            (grid_size**2) * batch_size,
            add_noise,
            add_rotation,
            add_blur,
            num_workers,
        )

        # Store hyperparameters
        self.class_idx = class_idx
        self.grid_size = grid_size

        # Save the existing transformations to be applied after creating the grid
        self.post_transform = self.transform
        # Set the pre-batch transformation to be the conversion from PIL to tensor
        self.transform = transforms.PILToTensor()
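        # (PILToTensor keeps images as uint8 CHW tensors, so grids can be
        # assembled before normalization and augmentation are applied)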

        # Specify the custom collate function and samplers
        self.collate_fn = self.custom_collate_fn
        self.shuffled_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=True,
        )
        self.sequential_sampler = partial(
            FairGridSampler,
            class_idx=class_idx,
            grid_size=grid_size,
            shuffle=False,
        )

    def custom_collate_fn(self, batch):
        # Split the flat batch into groups of grid_size**2 samples
        indices = range(len(batch))
        grids = zip(*(iter(indices),) * (self.grid_size**2))
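        # (zipping grid_size**2 copies of one iterator yields consecutive,
        # non-overlapping index tuples; a trailing remainder smaller than a
        # full grid is dropped)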

        new_batch = []
        for grid in grids:
            # Assemble the grid image: dstack joins grid_size patches along
            # the width to form a row, and hstack stacks the rows along the
            # height, giving a (C, grid_size * H, grid_size * W) tensor
            img = torch.hstack(
                [
                    torch.dstack(
                        [batch[i][0] for i in grid[row : row + self.grid_size]]
                    )
                    for row in range(0, self.grid_size**2, self.grid_size)
                ]
            )
            # Apply the post transformations to the grid
            img = self.post_transform(img)
            # Define the target as the number of images that have the class_idx
            targets = [batch[i][1] for i in grid]
            target = targets.count(self.class_idx)
            # Append grid and target to the batch
            new_batch += [(img, target)]

        return default_collate(new_batch)


class ToyQADataModule(CIFAR10QADataModule):
    """A datamodule for the toy dataset as described in the paper."""

    def prepare_data(self):
        # No need to download anything for the toy task
        pass

    def setup(self, stage: Optional[str] = None):
        img_size = 16

        samples = []
        # Generate 6000 samples: 1000 patches for each of the 6 valid colors
        for r, g, b in itertools.product((0, 1), (0, 1), (0, 1)):
            if r == g == b:
                # We do not want black/white patches
                continue

            for _ in range(1000):
                patch = torch.vstack(
                    [
                        r * torch.ones(1, img_size, img_size),
                        g * torch.ones(1, img_size, img_size),
                        b * torch.ones(1, img_size, img_size),
                    ]
                )

                # Assign a unique id to each color
                target = int(f"{r}{g}{b}", 2) - 1
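                # e.g. (0, 0, 1) -> 0 and (1, 1, 0) -> 5, covering ids 0-5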
                # Append the patch and target to the samples
                samples += [(patch, target)]

        # Split the data into 90% train, 5% validation, and 5% test
        train_size = int(len(samples) * 0.9)
        val_size = (len(samples) - train_size) // 2
        test_size = len(samples) - train_size - val_size
        self.train_data, self.val_data, self.test_data = random_split(
            samples,
            [
                train_size,
                val_size,
                test_size,
            ],
        )


class FairGridSampler(Sampler[int]):
    def __init__(
        self,
        dataset: VisionDataset,
        class_idx: int,
        grid_size: int,
        shuffle: bool = False,
    ):
        """A sampler that returns a grid of images from the dataset, with a uniformly random
         amount of appearances for a specific class of interest.

        Args:
            dataset (VisionDataset): the dataset to sample from
            class_idx (int): the class (index) to treat as the class of interest
            grid_size (int): the number of images per row in the grid
            shuffle (bool): whether to shuffle the dataset before sampling
        """
        super().__init__(dataset)

        # Save the hyperparameters
        self.dataset = dataset
        self.grid_size = grid_size
        self.n_images = grid_size**2

        # Get the indices of the class of interest
        self.class_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] == class_idx]
        )
        # Get the indices of all other classes
        self.other_indices = LongTensor(
            [i for i, x in enumerate(dataset) if x[1] != class_idx]
        )

        # Fix the seed if shuffle is False
        self.seed = None if shuffle else self._get_seed()
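        # With shuffle=True, a fresh seed is drawn on every __iter__ call,
        # so each epoch produces a different arrangement of grids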

    @staticmethod
    def _get_seed() -> int:
        """Utility function for generating a random seed."""
        return int(torch.empty((), dtype=torch.int64).random_().item())

    def __iter__(self) -> Iterator[int]:
        # Create a torch Generator object
        seed = self.seed if self.seed is not None else self._get_seed()
        gen = torch.Generator()
        gen.manual_seed(seed)

        # Sample the batches
        for _ in range(len(self.dataset) // self.n_images):
            # Pick the number of instances for the class of interest
            n_samples = torch.randint(self.n_images + 1, (), generator=gen).item()
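            # (uniform over {0, ..., grid_size**2}, so every possible count
            # target is equally likely; hence the "fair" in FairGridSampler)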

            # Sample the indices from the class of interest
            idx_from_class = torch.randperm(
                len(self.class_indices),
                generator=gen,
            )[:n_samples]
            # Sample the indices from the other classes
            idx_from_other = torch.randperm(
                len(self.other_indices),
                generator=gen,
            )[: self.n_images - n_samples]

            # Concatenate the corresponding lists of patches to form a grid
            grid = (
                self.class_indices[idx_from_class].tolist()
                + self.other_indices[idx_from_other].tolist()
            )

            # Shuffle the order of the patches within the grid, reusing the
            # seeded generator so that shuffle=False remains fully deterministic
            perm = torch.randperm(len(grid), generator=gen).tolist()
            yield from (grid[i] for i in perm)

    def __len__(self) -> int:
        # Only complete grids are yielded, so round down to a multiple of n_images
        return len(self.dataset) // self.n_images * self.n_images
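

# Example usage (a minimal sketch; it assumes the CIFAR10DataModule parent
# follows the usual LightningDataModule interface with prepare_data(),
# setup() and train_dataloader()):
#
#   dm = CIFAR10QADataModule(class_idx=3, grid_size=3, batch_size=32)
#   dm.prepare_data()
#   dm.setup("fit")
#   imgs, targets = next(iter(dm.train_dataloader()))
#   # imgs: batch_size grid images of (grid_size * H) x (grid_size * W) pixels
#   # targets: counts of class_idx patches, each in {0, ..., grid_size**2}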