import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode
from dataclasses import dataclass
from typing import Optional, Tuple, List
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel
from resnet_model.configuration_resnet import ResnetConfig


@dataclass
class BaseModelOutputWithCls(ModelOutput):
    """Output container for the classification model.

    Attributes are documented inline; both default to ``None``.
    """

    # Cross-entropy loss; only populated when labels are passed to forward().
    loss: Optional[torch.FloatTensor] = None
    # Raw (un-normalised) class scores returned by the wrapped network.
    logits: Optional[torch.FloatTensor] = None


class ResnetModel(PreTrainedModel):
    """Torchvision ResNet wrapped as a HuggingFace feature-extraction backbone.

    See https://huggingface.co/docs/transformers/custom_models

    Local usage::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

        res18_f = ResnetModel(res18_config)
        res18_f.save_pretrained("custom-resnet")
        res18_f = ResnetModel.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the backbone named by ``config.model_name``.

        Args:
            config: Configuration providing ``model_name`` (a factory name
                looked up in ``torchvision.models.resnet``), ``pretrained``
                (forwarded to that factory) and, optionally, ``out_indices``
                (which of the four stage outputs ``forward`` returns;
                defaults to all four).
        """
        super().__init__(config)
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        # The classification head is not used by the backbone; replace it so
        # it carries no parameters.
        self.model.fc = nn.Identity()

        # After construction torchvision's ResNet leaves ``inplanes`` equal to
        # the channel count produced by layer4; every earlier stage halves it,
        # so layer1..layer4 emit c5/8, c5/4, c5/2 and c5 channels.
        c5 = self.model.inplanes
        stage_channels = [c5 // 8, c5 // 4, c5 // 2, c5]

        out_indices = getattr(config, 'out_indices', [0, 1, 2, 3])
        self.out_indices = out_indices
        self.output_channels = [stage_channels[i] for i in out_indices]

    def forward(self, pixel_values, **kwargs):
        """Run the backbone and return the selected stage feature maps.

        Args:
            pixel_values: Input image batch fed to the first convolution
                (presumably ``(batch, 3, H, W)`` -- confirm against caller).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A list with the outputs of layer1..layer4 picked by
            ``self.out_indices``, in the order given there.
        """
        backbone = self.model

        # Stem: the layers live on the wrapped torchvision model, not on self.
        x = backbone.conv1(pixel_values)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        features = []
        stages = (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4)
        for stage in stages:
            x = stage(x)
            features.append(x)

        return [features[i] for i in self.out_indices]


class ResnetModelForImageClassification(PreTrainedModel):
    """Torchvision ResNet wrapped as a HuggingFace image classifier.

    See https://huggingface.co/docs/transformers/custom_models

    Local usage::

        res18_config = ResnetConfig('resnet18', True)
        res18_config.save_pretrained("custom-resnet")
        res18_config = ResnetConfig.from_pretrained("custom-resnet")

        res18_cls = ResnetModelForImageClassification(res18_config)
        res18_cls.save_pretrained("custom-resnet")
        res18_cls = ResnetModelForImageClassification.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the classifier named by ``config.model_name``.

        Args:
            config: Configuration providing ``model_name``, ``pretrained``
                and ``num_classes``.
        """
        super().__init__(config)
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        # Only swap the head when its width differs from the requested number
        # of classes, so a matching head created by the factory is kept.
        if self.model.fc.out_features != config.num_classes:
            self.model.fc = nn.Linear(self.model.fc.in_features, config.num_classes)

    def forward(self, pixel_values, labels=None, **kwargs):
        """Classify a batch and optionally compute the cross-entropy loss.

        Args:
            pixel_values: Input image batch passed to the wrapped network.
            labels: Optional targets; when given, the cross-entropy between
                the logits and these labels is returned as ``loss``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            BaseModelOutputWithCls with ``logits`` and ``loss`` (``None``
            when no labels were supplied).
        """
        logits = self.model(pixel_values)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return BaseModelOutputWithCls(loss=loss, logits=logits)