mbar0075's picture
Testing Commit
c9baa67
"""
This code was adapted from: https://github.com/rgeirhos/texture-vs-shape
"""
import os
import sys
from collections import OrderedDict
import torch
import torch.nn as nn
import torchvision
import torchvision.models
from torch.utils import model_zoo
from .normalizer import Normalizer
def load_model(model_name):
    """Build a texture-vs-shape network and load its pretrained checkpoint.

    The architecture is chosen by substring match on ``model_name``
    ("resnet50", "vgg16" or "alexnet", checked in that order).

    Args:
        model_name: Name of the checkpoint to load. For ResNet-50 and
            AlexNet it must be a key of the URL table below; for VGG-16 the
            weights are always read from a fixed local file, so only the
            "vgg16" substring matters.

    Returns:
        The constructed torch module with ``checkpoint["state_dict"]``
        loaded. ResNet-50 is wrapped in ``Sequential(module=...)`` and stays
        on the CPU; VGG-16 / AlexNet have their ``features`` wrapped in
        ``DataParallel`` and are moved to CUDA when a CUDA device is
        available.

    Raises:
        ValueError: If ``model_name`` matches none of the architectures.
        KeyError: If ``model_name`` matches ResNet-50 or AlexNet but is not
            a key of the URL table.
        FileNotFoundError: If the VGG-16 checkpoint file is not present in
            the current working directory.
    """
    model_urls = {
        'resnet50_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar',
        'resnet50_trained_on_SIN_and_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar',
        'resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar',
        'vgg16_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar',
        'alexnet_trained_on_SIN': 'https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar',
    }
    cpu = torch.device('cpu')

    if "resnet50" in model_name:
        model = torchvision.models.resnet50(pretrained=False)
        # Fake the DataParallel structure: the checkpoint keys are prefixed
        # with "module.", so wrap the network in a container whose only
        # child is named "module" instead of using a real DataParallel.
        model = torch.nn.Sequential(OrderedDict([('module', model)]))
        checkpoint = model_zoo.load_url(model_urls[model_name], map_location=cpu)
    elif "vgg16" in model_name:
        # The VGG checkpoint has to be downloaded manually and placed in the
        # current working directory.
        filepath = "./vgg16_train_60_epochs_lr0.01-6c6fcc9f.pth.tar"
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not os.path.exists(filepath):
            raise FileNotFoundError(
                "Please download the VGG model yourself from the following link and save it locally: https://drive.google.com/drive/folders/1A0vUWyU6fTuc-xWgwQQeBvzbwi6geYQK (too large to be downloaded automatically like the other models)"
            )
        model = torchvision.models.vgg16(pretrained=False)
        # The checkpoint keys contain "features.module.", hence the wrapper.
        model.features = torch.nn.DataParallel(model.features)
        # Only move to the GPU when one exists, so CPU-only hosts still work.
        if torch.cuda.is_available():
            model.cuda()
        checkpoint = torch.load(filepath, map_location=cpu)
    elif "alexnet" in model_name:
        model = torchvision.models.alexnet(pretrained=False)
        # The checkpoint keys contain "features.module.", hence the wrapper.
        model.features = torch.nn.DataParallel(model.features)
        # Only move to the GPU when one exists, so CPU-only hosts still work.
        if torch.cuda.is_available():
            model.cuda()
        checkpoint = model_zoo.load_url(model_urls[model_name], map_location=cpu)
    else:
        raise ValueError("unknown model architecture.")

    model.load_state_dict(checkpoint["state_dict"])
    return model
# --- DeepGaze Adaptation ----
class RGBShapeNetA(nn.Sequential):
    """Sequential pipeline: a ``Normalizer`` followed by the ResNet-50
    checkpoint registered as "resnet50_trained_on_SIN".

    The two stages are registered as children ``0`` (normalizer) and ``1``
    (network), in that order.
    """

    def __init__(self):
        # Load the network first, then the normalizer, then hand both to
        # nn.Sequential in application order.
        backbone = load_model("resnet50_trained_on_SIN")
        preprocess = Normalizer()
        super().__init__(preprocess, backbone)
class RGBShapeNetB(nn.Sequential):
    """Sequential pipeline: a ``Normalizer`` followed by the ResNet-50
    checkpoint registered as "resnet50_trained_on_SIN_and_IN".

    The two stages are registered as children ``0`` (normalizer) and ``1``
    (network), in that order.
    """

    def __init__(self):
        # Load the network first, then the normalizer, then hand both to
        # nn.Sequential in application order.
        backbone = load_model("resnet50_trained_on_SIN_and_IN")
        preprocess = Normalizer()
        super().__init__(preprocess, backbone)
class RGBShapeNetC(nn.Sequential):
    """Sequential pipeline: a ``Normalizer`` followed by the ResNet-50
    checkpoint registered as
    "resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN".

    The two stages are registered as children ``0`` (normalizer) and ``1``
    (network), in that order.
    """

    def __init__(self):
        # Load the network first, then the normalizer, then hand both to
        # nn.Sequential in application order.
        backbone = load_model("resnet50_trained_on_SIN_and_IN_then_finetuned_on_IN")
        preprocess = Normalizer()
        super().__init__(preprocess, backbone)