# custom-resnet18 / configuration_resnet.py
# Hugging Face Hub file view: uploaded by wucng ("Upload model", commit c8588d0, verified);
# file size 1.88 kB; security scan: no virus detected.
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode
from dataclasses import dataclass
from typing import Optional, Tuple, List
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel
class ResnetConfig(PretrainedConfig):
    """Configuration for a torchvision ResNet-family backbone exposed as a Hugging Face model.

    Stores which torchvision architecture to build (``model_name``), the
    ``pretrained`` flag, the number of input channels and the number of output
    classes, on top of the standard ``PretrainedConfig`` fields.

    Label maps: when ``id2label`` is not supplied, torchvision's ImageNet
    category names are used if ``num_classes`` equals the ImageNet category
    count; otherwise generic ``LABEL_<i>`` names are generated so that the
    label maps (and therefore ``num_labels``) agree with ``num_classes``.
    When ``label2id`` is not supplied it is derived from the resolved
    ``id2label`` so the two maps are always mutually consistent.

    Reference: https://huggingface.co/docs/transformers/custom_models

    Example (local usage)::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")
    """

    model_type = "resnet"
    # _name_or_path = "wucng/custom-resnet"
    # architectures = ['ResnetModel', 'ResnetModelForImageClassification']

    # Architecture names accepted for ``model_name`` (torchvision constructor names).
    SUPPORTED_MODELS = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d",
        "wide_resnet50_2", "wide_resnet101_2",
    )

    def __init__(
        self,
        model_name='resnet18',
        pretrained=False,
        input_channels: int = 3,
        num_classes: int = 1000,
        **kwargs
    ):
        """Validate the architecture name, store the settings and fill label defaults.

        Args:
            model_name: one of ``SUPPORTED_MODELS``; stored as ``self.model_name``.
            pretrained: stored as ``self.pretrained``; presumably tells the model
                class whether to load torchvision weights -- TODO confirm in the
                modeling code.
            input_channels: stored as ``self.input_channels``; presumably the
                channel count of input images.
            num_classes: stored as ``self.num_classes``; also sizes the default
                ``id2label`` / ``label2id`` maps.
            **kwargs: forwarded to ``PretrainedConfig.__init__``. ``id2label``,
                ``label2id`` and ``out_indices`` get defaults when absent or None.

        Raises:
            ValueError: if ``model_name`` is not in ``SUPPORTED_MODELS``.
        """
        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"`model_name` must be one of {self.SUPPORTED_MODELS}, got {model_name!r}."
            )
        self.model_name = model_name
        self.pretrained = pretrained
        self.input_channels = input_channels
        self.num_classes = num_classes

        id2label = kwargs.get('id2label')
        if id2label is None:
            # ImageNet names only make sense for an ImageNet-sized head; otherwise use
            # generic names so the label maps match ``num_classes``.
            if num_classes == len(_IMAGENET_CATEGORIES):
                names = _IMAGENET_CATEGORIES
            else:
                names = [f"LABEL_{i}" for i in range(num_classes)]
            id2label = {i: name for i, name in enumerate(names)}
            kwargs['id2label'] = id2label
        if kwargs.get('label2id') is None:
            # Derive from id2label so both maps agree; int() tolerates the string
            # keys produced by a JSON round-trip.
            kwargs['label2id'] = {name: int(i) for i, name in id2label.items()}
        if kwargs.get('out_indices') is None:
            # presumably the backbone stages whose features are returned -- TODO confirm
            kwargs['out_indices'] = [0, 1, 2, 3]
        super().__init__(**kwargs)