from dataclasses import dataclass
from typing import Optional, Tuple, List

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel


# Backbone names accepted by ``ResnetConfig(model_name=...)``.
_SUPPORTED_MODEL_NAMES = (
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet50_2",
    "wide_resnet101_2",
)


class ResnetConfig(PretrainedConfig):
    """Configuration for a ResNet-family model wrapped as a Hugging Face model.

    Follows the custom-model recipe described at
    https://huggingface.co/docs/transformers/custom_models

    Local usage::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

    Args:
        model_name: Name of the backbone; must be one of the entries in
            ``_SUPPORTED_MODEL_NAMES``.
        pretrained: Stored as-is on the config as ``self.pretrained``.
        input_channels: Stored as-is on the config as ``self.input_channels``.
        num_classes: Stored as-is on the config as ``self.num_classes``; also
            selects the default label names when ``id2label`` is not given.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. ``id2label``,
            ``label2id`` and ``out_indices`` receive defaults when they are
            missing or ``None``.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone name.
    """

    model_type = "resnet"
    # _name_or_path = "wucng/custom-resnet"
    # architectures = ['ResnetModel', 'ResnetModelForImageClassification']

    def __init__(
        self,
        model_name: str = 'resnet18',
        pretrained: bool = False,
        input_channels: int = 3,
        num_classes: int = 1000,
        **kwargs
    ):
        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let an invalid name through.
        if model_name not in _SUPPORTED_MODEL_NAMES:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of: "
                f"{', '.join(_SUPPORTED_MODEL_NAMES)}"
            )

        self.model_name = model_name
        self.pretrained = pretrained
        self.input_channels = input_channels
        self.num_classes = num_classes

        id2label = kwargs.get('id2label')
        if id2label is None:
            if num_classes == len(_IMAGENET_CATEGORIES):
                # Class count matches the ImageNet category list: use its names.
                id2label = dict(enumerate(_IMAGENET_CATEGORIES))
            else:
                # ImageNet names would not line up with a different class
                # count, so fall back to generic placeholder names.
                id2label = {i: f"LABEL_{i}" for i in range(num_classes)}
            kwargs['id2label'] = id2label

        if kwargs.get('label2id') is None:
            # Derive the inverse mapping from whichever id2label is in effect
            # so the two mappings always describe the same label set.
            # ``int(idx)`` tolerates string keys (e.g. from a JSON round-trip).
            kwargs['label2id'] = {name: int(idx) for idx, name in id2label.items()}

        if kwargs.get('out_indices') is None:
            kwargs['out_indices'] = [0, 1, 2, 3]

        super().__init__(**kwargs)