|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torchvision.models import resnet18, resnet |
|
from torchvision.models._meta import _IMAGENET_CATEGORIES |
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, List |
|
from transformers.modeling_outputs import ModelOutput |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
class ResnetConfig(PretrainedConfig):
    """Configuration for a torchvision ResNet / ResNeXt / Wide-ResNet model.

    Follows the Hugging Face custom-model pattern described at
    https://huggingface.co/docs/transformers/custom_models.

    Args:
        model_name: torchvision architecture name; must be one of
            ``ResnetConfig.SUPPORTED_MODEL_NAMES``.
        pretrained: Stored as ``self.pretrained``. Presumably tells the model
            whether to load pretrained weights (TODO confirm in the model class).
        input_channels: Number of input image channels, stored as-is.
        num_classes: Number of output classes, stored as-is. When no
            ``id2label`` is supplied, the default label map has exactly this
            many entries: ImageNet category names if ``num_classes`` equals the
            number of ImageNet categories, generic ``LABEL_i`` names otherwise.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. ``id2label``,
            ``label2id`` and ``out_indices`` receive defaults when they are
            missing or ``None``.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.

    Example (local usage)::

        >>> res18_config = ResnetConfig('resnet18', True)
        >>> res18_config.save_pretrained("custom-resnet")
        >>> res18_config = ResnetConfig.from_pretrained("custom-resnet")
    """

    model_type = "resnet"

    # Architectures accepted for ``model_name``.
    SUPPORTED_MODEL_NAMES = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d",
        "wide_resnet50_2", "wide_resnet101_2",
    )

    def __init__(
        self,
        model_name: str = 'resnet18',
        pretrained: bool = False,
        input_channels: int = 3,
        num_classes: int = 1000,
        **kwargs,
    ):
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept an invalid name.
        if model_name not in self.SUPPORTED_MODEL_NAMES:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of: "
                f"{', '.join(self.SUPPORTED_MODEL_NAMES)}."
            )

        self.model_name = model_name
        self.pretrained = pretrained
        self.input_channels = input_channels
        self.num_classes = num_classes

        # Default label maps. An explicit ``None`` is treated like a missing key.
        if kwargs.get('id2label') is None:
            if num_classes == len(_IMAGENET_CATEGORIES):
                kwargs['id2label'] = dict(enumerate(_IMAGENET_CATEGORIES))
            else:
                # Keep the label map the same size as the class count so the
                # num_labels that PretrainedConfig derives from id2label agrees
                # with num_classes; the generic names mirror its own defaults.
                kwargs['id2label'] = {i: f"LABEL_{i}" for i in range(num_classes)}
        if kwargs.get('label2id') is None:
            # Invert whichever id2label is in effect so the two maps always
            # agree. Ids may be str (e.g. from JSON), hence int(). For duplicate
            # names the last id wins, as in the previous implementation.
            kwargs['label2id'] = {name: int(i) for i, name in kwargs['id2label'].items()}
        if kwargs.get('out_indices') is None:
            # Presumably the backbone stages to expose -- TODO confirm in the model.
            kwargs['out_indices'] = [0, 1, 2, 3]

        super().__init__(**kwargs)
|
|