Spaces:
Running
Running
""" | |
This code was adapted from: https://github.com/rgeirhos/texture-vs-shape | |
""" | |
import os | |
import sys | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.models | |
from torch.utils import model_zoo | |
from .normalizer import Normalizer | |
def load_model(model_name): | |
model_urls = { | |
'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar', | |
'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar', | |
'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar', | |
'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar', | |
'alexnet_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar', | |
} | |
if "resnet50" in model_name: | |
#print("Using the ResNet50 architecture.") | |
model = torchvision.models.resnet50(pretrained=False) | |
#model = torch.nn.DataParallel(model) # .cuda() | |
# fake DataParallel structrue | |
model = torch.nn.Sequential(OrderedDict([('module', model)])) | |
checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu')) | |
elif "vgg16" in model_name: | |
#print("Using the VGG-16 architecture.") | |
# download model from URL manually and save to desired location | |
filepath = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar" | |
assert os.path.exists(filepath), "Please download the VGG model yourself from the following link and save it locally: https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK (too large to be downloaded automatically like the other models)" | |
model = torchvision.models.vgg16(pretrained=False) | |
model.features = torch.nn.DataParallel(model.features) | |
model.cuda() | |
checkpoint = torch.load(filepath, map_location=torch.device('cpu')) | |
elif "alexnet" in model_name: | |
#print("Using the AlexNet architecture.") | |
model = torchvision.models.alexnet(pretrained=False) | |
model.features = torch.nn.DataParallel(model.features) | |
model.cuda() | |
checkpoint = model_zoo.load_url(model_urls[model_name], map_location=torch.device('cpu')) | |
else: | |
raise ValueError("unknown model architecture.") | |
model.load_state_dict(checkpoint["state_dict"]) | |
return model | |
# --- DeepGaze Adaptation ---- | |
class RGBShapeNetA(nn.Sequential): | |
def __init__(self): | |
super(RGBShapeNetA, self).__init__() | |
self.shapenet = load_model("resnet50_trained_on_SIN") | |
self.normalizer = Normalizer() | |
super(RGBShapeNetA, self).__init__(self.normalizer, self.shapenet) | |
class RGBShapeNetB(nn.Sequential): | |
def __init__(self): | |
super(RGBShapeNetB, self).__init__() | |
self.shapenet = load_model("resnet50_trained_on_SIN_and_IN") | |
self.normalizer = Normalizer() | |
super(RGBShapeNetB, self).__init__(self.normalizer, self.shapenet) | |
class RGBShapeNetC(nn.Sequential): | |
def __init__(self): | |
super(RGBShapeNetC, self).__init__() | |
self.shapenet = load_model("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN") | |
self.normalizer = Normalizer() | |
super(RGBShapeNetC, self).__init__(self.normalizer, self.shapenet) | |