"""Factory helpers for building datamodules and feature-extractor configs."""
from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule
from .transformations import UnNest
from .visual_qa import CIFAR10QADataModule, ToyQADataModule
from argparse import Namespace
from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor
def get_configs(args: Namespace) -> tuple[dict, dict]:
    """Get the model and feature extractor configs from the command line args.

    Args:
        args (Namespace): the argparse Namespace object; must provide
            ``dataset``, and ``grid_size`` for the "CIFAR10_QA" / "toy" variants

    Returns:
        a tuple containing the model and feature extractor configs

    Raises:
        ValueError: if ``args.dataset`` is not a recognized dataset name
    """
    if args.dataset == "MNIST":
        # We upsample the MNIST images to 112x112, with 1 channel (grayscale)
        # and 10 classes (0-9). We normalize the image to have a mean of 0.5
        # and a standard deviation of ±0.5.
        model_cfg_args = {
            "image_size": 112,
            "num_channels": 1,
            "num_labels": 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5],
            "image_std": [0.5],
        }
    elif args.dataset.startswith("CIFAR10"):
        if args.dataset not in ("CIFAR10", "CIFAR10_QA"):
            # Raise ValueError (a specific Exception subclass) so callers can
            # catch input errors precisely; still caught by `except Exception`.
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")
        # We upsample the CIFAR10 images to 224x224, with 3 channels (RGB) and
        # 10 classes (0-9) for the normal dataset, or (grid_size)^2 + 1 for the
        # toy task. We normalize the image to have a mean of 0.5 and a standard
        # deviation of ±0.5.
        model_cfg_args = {
            "image_size": 224,  # fixed to 224 because pretrained models have that size
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1
            if args.dataset == "CIFAR10_QA"
            else 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    elif args.dataset == "toy":
        # We use an image size so that each patch contains a single color, with
        # 3 channels (RGB) and (grid_size)^2 + 1 classes. We normalize the image
        # to have a mean of 0.5 and a standard deviation of ±0.5.
        model_cfg_args = {
            "image_size": args.grid_size * 16,
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Set the feature extractor's size attribute to be the same as the model's image size
    fe_cfg_args["size"] = model_cfg_args["image_size"]
    # Set the tensors' return type to PyTorch tensors
    fe_cfg_args["return_tensors"] = "pt"

    return model_cfg_args, fe_cfg_args
def datamodule_factory(args: Namespace) -> ImageDataModule:
    """A factory method for creating a datamodule based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object; must provide
            ``base_model``, ``dataset``, ``from_pretrained``, ``batch_size``,
            ``add_noise``, ``add_rotation``, ``add_blur`` and ``num_workers``
            (plus ``class_idx``/``grid_size`` for the QA and toy variants)

    Returns:
        an ImageDataModule instance

    Raises:
        ValueError: if ``args.base_model`` or ``args.dataset`` is unrecognized
    """
    # Get the model and feature extractor configs
    model_cfg_args, fe_cfg_args = get_configs(args)

    # Set the feature extractor class based on the provided base model name
    if args.base_model == "ViT":
        fe_class = ViTFeatureExtractor
    elif args.base_model == "ConvNeXt":
        fe_class = ConvNextFeatureExtractor
    else:
        # Raise ValueError (a specific Exception subclass) so callers can
        # catch input errors precisely; still caught by `except Exception`.
        raise ValueError(f"Unknown base model: {args.base_model}")

    # Create the feature extractor instance, loading pretrained weights'
    # preprocessing settings when a checkpoint name is given
    if args.from_pretrained:
        feature_extractor = fe_class.from_pretrained(
            args.from_pretrained, **fe_cfg_args
        )
    else:
        feature_extractor = fe_class(**fe_cfg_args)

    # Un-nest the feature extractor's output
    feature_extractor = UnNest(feature_extractor)

    # Define the datamodule's configuration
    dm_cfg = {
        "feature_extractor": feature_extractor,
        "batch_size": args.batch_size,
        "add_noise": args.add_noise,
        "add_rotation": args.add_rotation,
        "add_blur": args.add_blur,
        "num_workers": args.num_workers,
    }

    # Determine the dataset class based on the provided dataset name
    if args.dataset.startswith("CIFAR10"):
        if args.dataset == "CIFAR10":
            dm_class = CIFAR10DataModule
        elif args.dataset == "CIFAR10_QA":
            # The QA variant additionally needs the target class and grid size
            dm_cfg["class_idx"] = args.class_idx
            dm_cfg["grid_size"] = args.grid_size
            dm_class = CIFAR10QADataModule
        else:
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")
    elif args.dataset == "MNIST":
        dm_class = MNISTDataModule
    elif args.dataset == "toy":
        # The toy QA task also needs the target class and grid size
        dm_cfg["class_idx"] = args.class_idx
        dm_cfg["grid_size"] = args.grid_size
        dm_class = ToyQADataModule
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    return dm_class(**dm_cfg)