# Scraped from a file viewer: file size 1,878 bytes, commit c8588d0.
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode

from dataclasses import dataclass
from typing import Optional, Tuple, List
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel

class ResnetConfig(PretrainedConfig):
    """Hugging Face configuration for a torchvision ResNet-family backbone.

    Follows the custom-model recipe from
    https://huggingface.co/docs/transformers/custom_models

    Example (local use)::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

    Args:
        model_name: torchvision architecture name; must be one of
            ``SUPPORTED_MODEL_NAMES``.
        pretrained: stored as ``self.pretrained``; presumably tells the model
            code whether to load torchvision weights (TODO confirm).
        input_channels: stored as ``self.input_channels``; presumably the
            number of channels the network's first layer expects (TODO confirm).
        num_classes: stored as ``self.num_classes``; also sizes the default
            ``id2label`` when none is supplied.
        **kwargs: forwarded to ``PretrainedConfig.__init__``. When absent or
            ``None`` these are filled in first:
            ``id2label`` -- ImageNet category names if ``num_classes`` equals
            the number of ImageNet categories, otherwise ``LABEL_<i>``
            placeholders for ``range(num_classes)``;
            ``label2id`` -- the inverse of ``id2label``;
            ``out_indices`` -- ``[0, 1, 2, 3]`` (presumably the backbone
            stages to expose; TODO confirm against the model code).

    Raises:
        ValueError: if ``model_name`` is not in ``SUPPORTED_MODEL_NAMES``.
    """
    model_type = "resnet"

    # Architecture names accepted for ``model_name``.
    SUPPORTED_MODEL_NAMES = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d",
        "wide_resnet50_2", "wide_resnet101_2",
    )

    def __init__(
            self,
            model_name='resnet18',
            pretrained=False,
            input_channels: int = 3,
            num_classes: int = 1000,
            **kwargs
    ):
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this validation.
        if model_name not in self.SUPPORTED_MODEL_NAMES:
            raise ValueError(
                f"model_name must be one of {self.SUPPORTED_MODEL_NAMES}, "
                f"got {model_name!r}"
            )
        self.model_name = model_name
        self.pretrained = pretrained
        self.input_channels = input_channels
        self.num_classes = num_classes

        if kwargs.get('id2label') is None:
            if num_classes == len(_IMAGENET_CATEGORIES):
                kwargs['id2label'] = dict(enumerate(_IMAGENET_CATEGORIES))
            else:
                # ImageNet names would mislabel a head of a different size and
                # make PretrainedConfig's num_labels disagree with num_classes.
                kwargs['id2label'] = {i: f"LABEL_{i}" for i in range(num_classes)}
        if kwargs.get('label2id') is None:
            # Derive from id2label so the two maps always agree. int() because
            # id2label keys read back from a saved JSON config are strings.
            kwargs['label2id'] = {
                name: int(idx) for idx, name in kwargs['id2label'].items()
            }
        if kwargs.get('out_indices') is None:
            kwargs['out_indices'] = [0, 1, 2, 3]

        super().__init__(**kwargs)