Paul Engstler
Initial commit
92f0e98
raw
history blame
3.54 kB
import os
import torch
from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264
from backbones.unet3d import UNet3D
import utils.config
def _freeze_layers_if_any(model, hparams):
    """Stop gradient updates for parameters whose names match a configured prefix.

    ``hparams.frozen_layers`` is a collection of name prefixes that are
    compared against the names yielded by ``model.named_parameters()``.
    Every parameter whose name starts with at least one prefix gets
    ``requires_grad = False``. The model is modified in place and returned.
    An empty ``frozen_layers`` leaves the model untouched.
    """
    prefixes = hparams.frozen_layers
    if len(prefixes) != 0:
        matching = (
            tensor
            for param_name, tensor in model.named_parameters()
            if any(param_name.startswith(prefix) for prefix in prefixes)
        )
        for tensor in matching:
            tensor.requires_grad = False
    return model
def _replace_inplace_operations(model):
    """Switch off in-place computation on every submodule that exposes an ``inplace`` flag.

    This is done for Grad-CAM compatibility. Every module reached via
    ``model.modules()`` that has an ``inplace`` attribute has it set to
    ``False``. The model is modified in place and returned.
    """
    inplace_capable = [sub for sub in model.modules() if hasattr(sub, "inplace")]
    for sub in inplace_capable:
        sub.inplace = False
    return model
def _num_outputs(hparams):
    """Return the number of output units the network head must produce.

    This is ``hparams.num_classes``, minus one when
    ``hparams.loss == 'ordinal_regression'``.
    """
    return hparams.num_classes - (hparams.loss == 'ordinal_regression')


def _build_densenet(hparams, in_channels):
    """Build the MONAI DenseNet variant named by ``hparams.model_name``.

    Raises:
        ValueError: if ``hparams.model_name`` is not a supported DenseNet.
    """
    variants = {
        "DenseNet121": DenseNet121,
        "DenseNet169": DenseNet169,
        "DenseNet201": DenseNet201,
        "DenseNet264": DenseNet264,
    }
    net_cls = variants.get(hparams.model_name)
    if net_cls is None:
        raise ValueError(f"Unknown DenseNet: {hparams.model_name}")
    backbone = net_cls(
        spatial_dims=hparams.input_dim,
        in_channels=in_channels,
        out_channels=_num_outputs(hparams),
        dropout_prob=hparams.dropout,
        act=("relu", {"inplace": False}),  # inplace has to be set to False to enable use of Grad-CAM
    )
    # ensure activation maps are not shrunk too much: drop the pooling step of
    # the second and third transition blocks
    backbone.features.transition2.pool = torch.nn.Identity()
    backbone.features.transition3.pool = torch.nn.Identity()
    return backbone


def _build_resnet(hparams):
    """Load a torchvision ResNet-family model from torch.hub and resize its classifier.

    A ``-pretrained`` suffix on ``hparams.model_name`` (case-insensitive)
    requests pretrained weights. The suffix is stripped and the name is
    lowercased before the hub lookup, because the hub entrypoints are plain
    lowercase names such as ``resnet18``.
    """
    suffix = "-pretrained"
    hub_name = hparams.model_name.lower()
    pretrained = hub_name.endswith(suffix)
    if pretrained:
        hub_name = hub_name[:-len(suffix)]
    # if you use pre-trained models, please add "pretrained_resnet" to the transforms hyperparameter
    # NOTE(review): in_channels / input_dim are not applied to this model; it
    # keeps torchvision's default input stem — confirm inputs match it.
    backbone = torch.hub.load('pytorch/vision:v0.10.0', hub_name, pretrained=pretrained)
    # Replace the final fully connected layer instead of only relabelling it.
    # Assigning ``fc.out_features`` leaves the weight matrix at its original
    # shape, so the network would keep emitting the original number of logits.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, _num_outputs(hparams))
    return backbone


def _build_unet3d(hparams, in_channels):
    """Construct the Models Genesis UNet3D architecture without loading any weights."""
    return UNet3D(
        in_channels=in_channels,
        input_size=hparams.input_size,
        n_class=_num_outputs(hparams),
    )


def _load_models_genesis_weights(backbone):
    """Load the Models Genesis checkpoint into ``backbone`` non-strictly and return it.

    The checkpoint path is ``data_sl/<MODELS_GENESIS_PATH>``, taken from
    ``utils.config.globals``.
    """
    checkpoint_path = os.path.join('data_sl', utils.config.globals["MODELS_GENESIS_PATH"])
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # Remove "module." from parameter names. It was presumably added by a
    # DataParallel wrapper when the checkpoint was saved. Note that every
    # occurrence is removed, not only a leading prefix.
    state_dict = {
        key.replace("module.", ""): value
        for key, value in checkpoint['state_dict'].items()
    }
    # strict=False: keys that do not match between checkpoint and model are ignored
    backbone.load_state_dict(state_dict, strict=False)
    return backbone


def get_backbone(hparams):
    """Build the network described by ``hparams``.

    Supported ``hparams.model_name`` values:
      * ``DenseNet121`` / ``DenseNet169`` / ``DenseNet201`` / ``DenseNet264``
        (MONAI, with pooling removed from transitions 2 and 3);
      * any name starting with ``resne`` (case-insensitive), loaded from
        torch.hub; append ``-pretrained`` to request pretrained weights;
      * ``ModelsGenesis``: UNet3D initialised from the Models Genesis checkpoint;
      * ``UNet3D``: the same architecture without the pretrained weights.

    Every model then has in-place operations disabled (for Grad-CAM) and
    the prefixes listed in ``hparams.frozen_layers`` frozen.

    Raises:
        ValueError: if the name starts with ``DenseNet`` but is not a known variant.
        NotImplementedError: for any other unsupported ``model_name``.
    """
    # 1 image channel, +1 if the mask is fed as an extra channel, plus
    # input_dim coordinate channels when hparams.coordinates is truthy
    # (assumes coordinates is a bool / 0-1 flag).
    in_channels = 1 + (hparams.mask == 'channel') + hparams.input_dim * hparams.coordinates
    name = hparams.model_name
    if name.startswith('DenseNet'):
        backbone = _build_densenet(hparams, in_channels)
    elif name.lower().startswith("resne"):
        backbone = _build_resnet(hparams)
    elif name == 'ModelsGenesis':
        backbone = _load_models_genesis_weights(_build_unet3d(hparams, in_channels))
    elif name == 'UNet3D':
        # this is the architecture of Models Genesis minus the pretraining
        backbone = _build_unet3d(hparams, in_channels)
    else:
        raise NotImplementedError(f"Unsupported model_name: {name}")
    backbone = _replace_inplace_operations(backbone)
    return _freeze_layers_if_any(backbone, hparams)