custom-resnet18 / modeling_resnet.py
wucng's picture
Upload model
c8588d0 verified
raw
history blame
3.72 kB
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode
from dataclasses import dataclass
from typing import Optional, Tuple, List
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel
from resnet_model.configuration_resnet import ResnetConfig
@dataclass
class BaseModelOutputWithCls(ModelOutput):
    """Output container for the classification model: an optional loss and the logits."""

    # Loss value; stays None when the caller supplied no labels.
    loss: Optional[torch.FloatTensor] = None
    # Class scores produced by the network. Annotated Optional because the
    # declared default is None.
    logits: Optional[torch.FloatTensor] = None
class ResnetModel(PreTrainedModel):
    """Torchvision ResNet wrapped as a feature-extraction backbone.

    The torchvision network is stored in ``self.model`` with its ``fc`` head
    replaced by ``nn.Identity``. ``forward`` runs the stem and the four
    residual stages and returns the stage outputs selected by
    ``config.out_indices`` (all four stages when the config does not define it).

    See https://huggingface.co/docs/transformers/custom_models

    Local usage::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

        res18_f = ResnetModel(res18_config)
        res18_f.save_pretrained("custom-resnet")
        res18_f = ResnetModel.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the backbone described by ``config``.

        Args:
            config: Configuration object providing ``model_name`` (name of a
                constructor in ``torchvision.models.resnet``), ``pretrained``
                (forwarded to that constructor) and, optionally,
                ``out_indices`` (indices of the stages to return).
        """
        super().__init__(config)
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        # The classification head is never used by forward(); neutralize it.
        self.model.fc = nn.Identity()

        # NOTE(review): relies on torchvision's ResNet leaving ``inplanes`` at
        # the channel width of the last stage once all layers are built.
        c5 = self.model.inplanes
        # Channel width doubles at each stage, so layer1..layer4 emit
        # c5/8, c5/4, c5/2 and c5 channels (64/128/256/512 for resnet18).
        stage_channels = [c5 // 8, c5 // 4, c5 // 2, c5]

        self.out_indices = getattr(config, 'out_indices', [0, 1, 2, 3])
        self.output_channels = [stage_channels[i] for i in self.out_indices]

    def forward(self, pixel_values, **kwargs):
        """Return the feature maps of the selected residual stages.

        Args:
            pixel_values: Input image batch fed to the network's first
                convolution (presumably ``(batch, 3, H, W)`` - confirm
                against the caller).
            **kwargs: Accepted and ignored.

        Returns:
            A list with one tensor per entry of ``self.out_indices``, in that
            order, where index ``i`` refers to the output of ``layer{i + 1}``.
        """
        net = self.model

        # Stem: conv -> batch norm -> ReLU -> max pool.
        x = net.conv1(pixel_values)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        # Run every stage in order; each one consumes the previous output.
        features = []
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
            features.append(x)

        return [features[i] for i in self.out_indices]
class ResnetModelForImageClassification(PreTrainedModel):
    """Torchvision ResNet wrapped as an image classifier.

    The torchvision network is kept whole in ``self.model``; its ``fc`` layer
    is swapped for a fresh ``nn.Linear`` only when its output size differs
    from ``config.num_classes``.

    See https://huggingface.co/docs/transformers/custom_models

    Local usage::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

        res18_cls = ResnetModelForImageClassification(res18_config)
        res18_cls.save_pretrained("custom-resnet")
        res18_cls = ResnetModelForImageClassification.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the classifier described by ``config``.

        Args:
            config: Configuration object providing ``model_name`` (name of a
                constructor in ``torchvision.models.resnet``), ``pretrained``
                (forwarded to that constructor) and ``num_classes``.
        """
        super().__init__(config)
        build_network = getattr(resnet, config.model_name)
        network = build_network(config.pretrained)

        # Keep the stock head when it already has the requested width.
        head = network.fc
        if head.out_features != config.num_classes:
            network.fc = nn.Linear(head.in_features, config.num_classes)

        self.model = network

    def forward(self, pixel_values, labels=None, **kwargs):
        """Classify a batch and, when labels are given, compute the loss.

        Args:
            pixel_values: Input image batch passed straight to the network.
            labels: Optional targets for ``cross_entropy``; when ``None`` no
                loss is computed.
            **kwargs: Accepted and ignored.

        Returns:
            ``BaseModelOutputWithCls`` holding ``logits`` and ``loss``
            (``None`` without labels).
        """
        scores = self.model(pixel_values)
        if labels is None:
            return BaseModelOutputWithCls(loss=None, logits=scores)
        return BaseModelOutputWithCls(
            loss=F.cross_entropy(scores, labels),
            logits=scores,
        )