File size: 2,128 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from datamodules.utils import get_configs
from transformers import (
    ConvNextConfig,
    ConvNextForImageClassification,
    PreTrainedModel,
    ViTConfig,
    ViTForImageClassification,
)

import argparse
import torch


def set_clf_head(base: "PreTrainedModel", num_classes: int):
    """Replace the model's classification head if its output size does not match.

    The new head is a freshly initialised ``torch.nn.Linear`` with the same input
    width, device and dtype as the old one. The model's ``num_labels`` attribute and
    ``base.config.num_labels`` are updated too, so code that relies on them (e.g. a
    loss computation reshaping logits by ``num_labels``, or reloading a checkpoint
    written with ``save_pretrained``) agrees with the new head.

    Args:
        base (PreTrainedModel): the model to modify in place; assumes
            ``base.classifier`` is a ``torch.nn.Linear`` (anything without
            ``out_features``/``in_features`` raises ``AttributeError``)
        num_classes (int): the number of classes to use for the output layer
    """
    old_head = base.classifier
    if old_head.out_features == num_classes:
        # Head already has the requested size; leave the model untouched.
        return

    # Match the old head's placement and precision so the model stays homogeneous
    # (e.g. a model already moved to GPU or converted to half precision).
    base.classifier = torch.nn.Linear(
        old_head.in_features,
        num_classes,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )

    # Keep cached label counts consistent with the new head. In transformers'
    # PretrainedConfig, the num_labels setter also regenerates id2label/label2id
    # when their size no longer matches.
    base.num_labels = num_classes
    base.config.num_labels = num_classes


def model_factory(
    args: argparse.Namespace,
    own_config: bool = False,
) -> PreTrainedModel:
    """Create a HuggingFace image-classification model from the command line args.

    Args:
        args (Namespace): the argparse Namespace object; ``args.base_model`` selects
            the architecture ("ViT" or "ConvNeXt"), ``args.from_pretrained`` names an
            optional checkpoint, and ``args`` is passed to ``get_configs`` to obtain
            the model config kwargs (which must include ``num_labels``)
        own_config (bool): if True, always build a freshly initialised model from our
            own config and ignore ``args.from_pretrained`` (no pretrained weights are
            loaded). If False and ``args.from_pretrained`` is set, the checkpoint is
            loaded and its classifier head is replaced when its size differs from the
            configured ``num_labels``. So a checkpoint trained with a different number
            of classes can still be reused with ``own_config=False``.

    Returns:
        a PreTrainedModel instance

    Raises:
        ValueError: if ``args.base_model`` is not a supported architecture
    """
    # CLI architecture name -> (config class, model class)
    architectures = {
        "ViT": (ViTConfig, ViTForImageClassification),
        "ConvNeXt": (ConvNextConfig, ConvNextForImageClassification),
    }
    if args.base_model not in architectures:
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError(f"Unknown base model: {args.base_model}")
    config_class, base_class = architectures[args.base_model]

    # Get the model config kwargs (the second return value is not needed here)
    model_cfg_args, _ = get_configs(args)

    if not own_config and args.from_pretrained:
        # Load pretrained weights, then resize the classifier head if the
        # checkpoint was trained with a different number of classes.
        base = base_class.from_pretrained(args.from_pretrained)
        set_clf_head(base, model_cfg_args["num_labels"])
    else:
        # Randomly initialised model built from our own config
        config = config_class(**model_cfg_args)
        base = base_class(config)

    return base