import argparse

import torch
from transformers import (
    ConvNextConfig,
    ConvNextForImageClassification,
    PreTrainedModel,
    ViTConfig,
    ViTForImageClassification,
)

from datamodules.utils import get_configs


def set_clf_head(base: PreTrainedModel, num_classes: int):
    """Set the classification head of the model in case of an output mismatch.

    Args:
        base (PreTrainedModel): the model to modify
        num_classes (int): the number of classes to use for the output layer
    """
    if base.classifier.out_features != num_classes:
        # Swap in a freshly initialized linear head with the requested number
        # of output classes, keeping the original input dimension.
        in_features = base.classifier.in_features
        base.classifier = torch.nn.Linear(in_features, num_classes)


def model_factory(
    args: argparse.Namespace,
    own_config: bool = False,
) -> PreTrainedModel:
    """A factory method for creating a HuggingFace model based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object
        own_config (bool): whether to create our own model config instead of a pretrained one;
            this is recommended when the model was pre-trained on another task with a different
            number of classes for its classifier head

    Returns:
        a PreTrainedModel instance
    """
if args.base_model == "ViT":
# Create a new Vision Transformer
config_class = ViTConfig
base_class = ViTForImageClassification
elif args.base_model == "ConvNeXt":
# Create a new ConvNext model
config_class = ConvNextConfig
base_class = ConvNextForImageClassification
else:
        raise ValueError(f"Unknown base model: {args.base_model}")
# Get the model config
model_cfg_args, _ = get_configs(args)
    if not own_config and args.from_pretrained:
        # Load the weights from the given pretrained checkpoint
        base = base_class.from_pretrained(args.from_pretrained)
        # Replace the classifier head if its output size does not match the task
        set_clf_head(base, model_cfg_args["num_labels"])
else:
# Create a model based on the config
config = config_class(**model_cfg_args)
base = base_class(config)
return base
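

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. The Namespace
    # fields below are assumptions: get_configs(args) may expect additional
    # dataset-specific arguments, and the checkpoint name is just one public
    # ViT checkpoint chosen for demonstration.
    demo_args = argparse.Namespace(
        base_model="ViT",
        from_pretrained="google/vit-base-patch16-224",
    )
    model = model_factory(demo_args)
    print(type(model).__name__, model.num_parameters())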