File size: 4,892 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule
from .transformations import UnNest
from .visual_qa import CIFAR10QADataModule, ToyQADataModule
from argparse import Namespace
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor


def get_configs(args: Namespace) -> tuple[dict, dict]:
    """Get the model and feature extractor configs from the command line args.

    Args:
        args (Namespace): the argparse Namespace object. ``args.dataset``
            selects the configuration ("MNIST", "CIFAR10", "CIFAR10_QA" or
            "toy"); ``args.grid_size`` is read only for the "CIFAR10_QA" and
            "toy" datasets.

    Returns:
         a tuple containing the model and feature extractor configs

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets
    """
    dataset = args.dataset

    if dataset == "MNIST":
        # We upsample the MNIST images to 112x112, with 1 channel (grayscale)
        # and 10 classes (0-9). The normalization uses a mean of 0.5 and a
        # standard deviation of 0.5.
        model_cfg_args = {
            "image_size": 112,
            "num_channels": 1,
            "num_labels": 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5],
            "image_std": [0.5],
        }
    elif dataset.startswith("CIFAR10"):
        if dataset not in ("CIFAR10", "CIFAR10_QA"):
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")

        # We upsample the CIFAR10 images to 224x224, with 3 channels (RGB) and
        # 10 classes (0-9) for the normal dataset, or (grid_size)^2 + 1 for the
        # QA task. The normalization uses a mean of 0.5 and a standard
        # deviation of 0.5 per channel.
        if dataset == "CIFAR10_QA":
            # grid_size is only read here, so plain CIFAR10 does not need it
            num_labels = (args.grid_size**2) + 1
        else:
            num_labels = 10

        model_cfg_args = {
            "image_size": 224,  # fixed to 224 because pretrained models have that size
            "num_channels": 3,
            "num_labels": num_labels,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    elif dataset == "toy":
        # We use an image size so that each patch contains a single color, with
        # 3 channels (RGB) and (grid_size)^2 + 1 classes. The normalization
        # uses a mean of 0.5 and a standard deviation of 0.5 per channel.
        model_cfg_args = {
            "image_size": args.grid_size * 16,
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Set the feature extractor's size attribute to be the same as the model's image size
    fe_cfg_args["size"] = model_cfg_args["image_size"]
    # Set the tensors' return type to PyTorch tensors
    fe_cfg_args["return_tensors"] = "pt"

    return model_cfg_args, fe_cfg_args


def datamodule_factory(args: Namespace) -> ImageDataModule:
    """A factory method for creating a datamodule based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object. Reads ``dataset``,
            ``base_model``, ``from_pretrained``, ``batch_size``,
            ``add_noise``, ``add_rotation``, ``add_blur`` and ``num_workers``;
            ``class_idx`` and ``grid_size`` are read only for the
            "CIFAR10_QA" and "toy" datasets.

    Returns:
        an ImageDataModule instance

    Raises:
        ValueError: if ``args.base_model`` or ``args.dataset`` is not supported
    """
    # Get the feature extractor config (the model config is not needed here)
    _, fe_cfg_args = get_configs(args)

    # Set the feature extractor class based on the provided base model name
    if args.base_model == "ViT":
        fe_class = ViTFeatureExtractor
    elif args.base_model == "ConvNeXt":
        fe_class = ConvNextFeatureExtractor
    else:
        raise ValueError(f"Unknown base model: {args.base_model}")

    # Create the feature extractor instance; a truthy ``from_pretrained`` is
    # forwarded to the class' ``from_pretrained`` constructor, with our config
    # values passed as keyword overrides
    if args.from_pretrained:
        feature_extractor = fe_class.from_pretrained(
            args.from_pretrained, **fe_cfg_args
        )
    else:
        feature_extractor = fe_class(**fe_cfg_args)

    # Un-nest the feature extractor's output
    feature_extractor = UnNest(feature_extractor)

    # Define the datamodule's configuration
    dm_cfg = {
        "feature_extractor": feature_extractor,
        "batch_size": args.batch_size,
        "add_noise": args.add_noise,
        "add_rotation": args.add_rotation,
        "add_blur": args.add_blur,
        "num_workers": args.num_workers,
    }

    # Determine the dataset class based on the provided dataset name
    if args.dataset.startswith("CIFAR10"):
        if args.dataset == "CIFAR10":
            dm_class = CIFAR10DataModule
        elif args.dataset == "CIFAR10_QA":
            # The QA datamodules additionally take the class index and grid size
            dm_cfg["class_idx"] = args.class_idx
            dm_cfg["grid_size"] = args.grid_size
            dm_class = CIFAR10QADataModule
        else:
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")
    elif args.dataset == "MNIST":
        dm_class = MNISTDataModule
    elif args.dataset == "toy":
        dm_cfg["class_idx"] = args.class_idx
        dm_cfg["grid_size"] = args.grid_size
        dm_class = ToyQADataModule
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    return dm_class(**dm_cfg)