|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torchvision.models import resnet18, resnet |
|
from torchvision.models._meta import _IMAGENET_CATEGORIES |
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, List |
|
from transformers.modeling_outputs import ModelOutput |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
from .configuration_resnet import ResnetConfig |
|
|
|
@dataclass
class BaseModelOutputWithCls(ModelOutput):
    """Output container returned by the image-classification model.

    Both fields default to ``None`` so the object can be built with or
    without a loss value.
    """

    # Loss value; left as ``None`` when the caller supplies no labels.
    loss: Optional[torch.FloatTensor] = None
    # Raw outputs of the classification head.
    logits: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
|
class ResnetModel(PreTrainedModel):
    """Torchvision ResNet backbone wrapped as a Hugging Face ``PreTrainedModel``.

    The torchvision classification head (``fc``) is replaced by
    ``nn.Identity`` and ``forward`` returns the feature maps produced by the
    residual stages selected through ``config.out_indices`` (all four stages
    when the config does not define it).

    See https://huggingface.co/docs/transformers/custom_models

    Local usage:

    >>> res18_config = ResnetConfig('resnet18', True)
    >>> res18_config.save_pretrained("custom-resnet")
    >>> res18_config = ResnetConfig.from_pretrained("custom-resnet")

    >>> res18_f = ResnetModel(res18_config)
    >>> res18_f.save_pretrained("custom-resnet")
    >>> res18_f = ResnetModel.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the backbone described by ``config``.

        Args:
            config: Configuration object providing ``model_name`` (name of a
                constructor in ``torchvision.models.resnet``), ``pretrained``
                (forwarded positionally to that constructor) and, optionally,
                ``out_indices`` (which of the four stages ``forward`` returns).
        """
        super().__init__(config)

        # Resolve the torchvision constructor by name, e.g. "resnet18".
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        # Only feature maps are needed here, so drop the classification head.
        self.model.fc = nn.Identity()

        # torchvision updates ``inplanes`` while building each stage, so after
        # construction it equals the channel count of the last stage (layer4).
        # Each earlier stage has half the channels of the next one.
        c5 = self.model.inplanes
        stage_channels = [c5 // 8, c5 // 4, c5 // 2, c5]

        out_indices = getattr(config, 'out_indices', None)
        if out_indices is None:
            out_indices = [0, 1, 2, 3]
        self.out_indices = out_indices
        # Channel count of every feature map returned by ``forward``.
        self.output_channels = [stage_channels[i] for i in out_indices]

    def forward(self, pixel_values, **kwargs):
        """Run the backbone and collect the selected stage outputs.

        Args:
            pixel_values: Input image batch fed to the ResNet stem.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A list with the outputs of ``layer1`` .. ``layer4`` picked by
            ``self.out_indices``, in that order.
        """
        # All layers live on the wrapped torchvision module, not on ``self``.
        backbone = self.model

        x = backbone.conv1(pixel_values)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        features = []
        stages = (
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )
        for stage in stages:
            x = stage(x)
            features.append(x)

        return [features[i] for i in self.out_indices]
|
|
|
|
|
class ResnetModelForImageClassification(PreTrainedModel):
    """Torchvision ResNet classifier wrapped as a Hugging Face ``PreTrainedModel``.

    See https://huggingface.co/docs/transformers/custom_models

    Local usage:

    >>> res18_config = ResnetConfig('resnet18', True)
    >>> res18_config.save_pretrained("custom-resnet")
    >>> res18_config = ResnetConfig.from_pretrained("custom-resnet")

    >>> res18_cls = ResnetModelForImageClassification(res18_config)
    >>> res18_cls.save_pretrained("custom-resnet")
    >>> res18_cls = ResnetModelForImageClassification.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Create the classifier named by ``config.model_name``.

        The ``fc`` head is swapped for a fresh ``nn.Linear`` only when its
        output size differs from ``config.num_classes``.
        """
        super().__init__(config)

        build_backbone = getattr(resnet, config.model_name)
        self.model = build_backbone(config.pretrained)

        head = self.model.fc
        if head.out_features != config.num_classes:
            self.model.fc = nn.Linear(head.in_features, config.num_classes)

    def forward(self, pixel_values, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            pixel_values: Input image batch passed to the wrapped model.
            labels: Optional targets for ``cross_entropy``; no loss is
                computed when this is ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``BaseModelOutputWithCls`` holding ``loss`` (or ``None``) and
            ``logits``.
        """
        logits = self.model(pixel_values)

        if labels is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, labels)

        return BaseModelOutputWithCls(loss=loss, logits=logits)
|
|
|
|