Spaces:
Runtime error
Runtime error
from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule | |
from .transformations import UnNest | |
from .visual_qa import CIFAR10QADataModule, ToyQADataModule | |
from argparse import Namespace | |
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor | |
def get_configs(args: Namespace) -> tuple[dict, dict]:
    """Get the model and feature extractor configs from the command line args.

    Args:
        args (Namespace): the argparse Namespace object. ``args.dataset``
            selects the configuration; ``args.grid_size`` is only read for
            the "CIFAR10_QA" and "toy" datasets.

    Returns:
        a tuple containing the model and feature extractor configs

    Raises:
        ValueError: if ``args.dataset`` is not one of "MNIST", "CIFAR10",
            "CIFAR10_QA" or "toy".
    """
    dataset = args.dataset

    if dataset == "MNIST":
        # MNIST images are upsampled to 112x112, with 1 channel (grayscale)
        # and 10 classes (0-9).
        model_cfg_args = {
            "image_size": 112,
            "num_channels": 1,
            "num_labels": 10,
        }
    elif dataset.startswith("CIFAR10"):
        if dataset not in ("CIFAR10", "CIFAR10_QA"):
            raise ValueError(f"Unknown CIFAR10 variant: {dataset}")

        # CIFAR10 images are upsampled to 224x224, with 3 channels (RGB) and
        # 10 classes (0-9) for the normal dataset, or (grid_size)^2 + 1 for
        # the QA variant. grid_size is only read when it is actually needed.
        if dataset == "CIFAR10_QA":
            num_labels = (args.grid_size**2) + 1
        else:
            num_labels = 10

        model_cfg_args = {
            "image_size": 224,  # fixed to 224 because pretrained models have that size
            "num_channels": 3,
            "num_labels": num_labels,
        }
    elif dataset == "toy":
        # The image size is chosen so that each patch contains a single
        # color, with 3 channels (RGB) and (grid_size)^2 + 1 classes.
        model_cfg_args = {
            "image_size": args.grid_size * 16,
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1,
        }
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # Every dataset is normalized with a per-channel mean of 0.5 and a
    # per-channel standard deviation of 0.5.
    num_channels = model_cfg_args["num_channels"]
    fe_cfg_args = {
        "image_mean": [0.5] * num_channels,
        "image_std": [0.5] * num_channels,
        # The feature extractor's size must match the model's image size
        "size": model_cfg_args["image_size"],
        # Return PyTorch tensors
        "return_tensors": "pt",
    }

    return model_cfg_args, fe_cfg_args
def datamodule_factory(args: Namespace) -> ImageDataModule:
    """A factory method for creating a datamodule based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object. The attributes read
            are ``dataset``, ``base_model``, ``from_pretrained``,
            ``batch_size``, ``add_noise``, ``add_rotation``, ``add_blur`` and
            ``num_workers``; ``class_idx`` and ``grid_size`` are additionally
            read for the "CIFAR10_QA" and "toy" datasets.

    Returns:
        an ImageDataModule instance

    Raises:
        ValueError: if ``args.base_model`` or ``args.dataset`` is not
            supported.
    """
    # Get the feature extractor config (the model config is not needed here).
    # This also validates args.dataset before anything is instantiated.
    _, fe_cfg_args = get_configs(args)

    # Set the feature extractor class based on the provided base model name
    fe_classes = {
        "ViT": ViTFeatureExtractor,
        "ConvNeXt": ConvNextFeatureExtractor,
    }
    if args.base_model not in fe_classes:
        raise ValueError(f"Unknown base model: {args.base_model}")
    fe_class = fe_classes[args.base_model]

    # Create the feature extractor instance, starting from a pretrained
    # checkpoint when one was given
    if args.from_pretrained:
        feature_extractor = fe_class.from_pretrained(
            args.from_pretrained, **fe_cfg_args
        )
    else:
        feature_extractor = fe_class(**fe_cfg_args)

    # Un-nest the feature extractor's output
    feature_extractor = UnNest(feature_extractor)

    # Define the datamodule's configuration
    dm_cfg = {
        "feature_extractor": feature_extractor,
        "batch_size": args.batch_size,
        "add_noise": args.add_noise,
        "add_rotation": args.add_rotation,
        "add_blur": args.add_blur,
        "num_workers": args.num_workers,
    }

    # Determine the dataset class based on the provided dataset name
    dataset = args.dataset
    if dataset == "CIFAR10":
        dm_class = CIFAR10DataModule
    elif dataset == "CIFAR10_QA":
        dm_class = CIFAR10QADataModule
    elif dataset.startswith("CIFAR10"):
        raise ValueError(f"Unknown CIFAR10 variant: {dataset}")
    elif dataset == "MNIST":
        dm_class = MNISTDataModule
    elif dataset == "toy":
        dm_class = ToyQADataModule
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # The QA-style datamodules take two extra arguments
    if dataset in ("CIFAR10_QA", "toy"):
        dm_cfg["class_idx"] = args.class_idx
        dm_cfg["grid_size"] = args.grid_size

    return dm_class(**dm_cfg)