|
import os |
|
import torch |
|
|
|
from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264 |
|
from backbones.unet3d import UNet3D |
|
|
|
import utils.config |
|
|
|
def _freeze_layers_if_any(model, hparams): |
|
if len(hparams.frozen_layers) == 0: |
|
return model |
|
|
|
for (name, param) in model.named_parameters(): |
|
if any([name.startswith(to_freeze_name) for to_freeze_name in hparams.frozen_layers]): |
|
param.requires_grad = False |
|
|
|
return model |
|
|
|
def _replace_inplace_operations(model): |
|
|
|
for module in model.modules(): |
|
if hasattr(module, "inplace"): |
|
setattr(module, "inplace", False) |
|
return model |
|
|
|
def get_backbone(hparams):
    """Build the backbone network selected by ``hparams.model_name``.

    Supported names:
      * ``DenseNet121`` / ``DenseNet169`` / ``DenseNet201`` / ``DenseNet264``:
        MONAI DenseNets with ``hparams.input_dim`` spatial dimensions.
      * Any name starting with ``resne`` (case-insensitive): a torchvision model
        loaded through ``torch.hub``. A trailing ``-pretrained`` requests
        pretrained weights.
      * ``ModelsGenesis``: a ``UNet3D`` initialised from the checkpoint at
        ``data_sl/<MODELS_GENESIS_PATH>``.
      * ``UNet3D``: a randomly initialised ``UNet3D``.

    After construction, in-place operations are disabled and the parameters named
    by ``hparams.frozen_layers`` are frozen.

    Args:
        hparams: Hyper-parameter namespace. Reads ``model_name``, ``mask``,
            ``input_dim``, ``coordinates``, ``num_classes``, ``loss``, and
            ``frozen_layers``. DenseNets also read ``dropout``; the UNet3D
            variants also read ``input_size``.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_name`` starts with ``DenseNet`` but is not a known variant.
        NotImplementedError: If ``model_name`` matches no supported family.
    """
    # 1 base channel (presumably the image), +1 when the mask is fed as an extra
    # channel, + input_dim coordinate channels (assumes `coordinates` is a bool flag).
    in_channels = 1 + (hparams.mask == 'channel') + hparams.input_dim * hparams.coordinates
    # Ordinal regression uses one output fewer than num_classes (presumably one
    # per threshold between consecutive classes).
    out_channels = hparams.num_classes - (hparams.loss == 'ordinal_regression')
    model_name = hparams.model_name

    if model_name.startswith('DenseNet'):
        backbone = _build_densenet(hparams, in_channels, out_channels)
    elif model_name.lower().startswith('resne'):
        backbone = _build_torchvision_resnet(model_name, out_channels)
    elif model_name in ('ModelsGenesis', 'UNet3D'):
        backbone = UNet3D(
            in_channels=in_channels,
            input_size=hparams.input_size,
            n_class=out_channels,
        )
        if model_name == 'ModelsGenesis':
            _load_models_genesis_weights(backbone)
    else:
        raise NotImplementedError(f"Unsupported model_name: {model_name}")

    backbone = _replace_inplace_operations(backbone)
    backbone = _freeze_layers_if_any(backbone, hparams)
    return backbone


def _build_densenet(hparams, in_channels, out_channels):
    """Build the MONAI DenseNet variant named by ``hparams.model_name``."""
    variants = {
        "DenseNet121": DenseNet121,
        "DenseNet169": DenseNet169,
        "DenseNet201": DenseNet201,
        "DenseNet264": DenseNet264,
    }
    try:
        net_selection = variants[hparams.model_name]
    except KeyError:
        raise ValueError(f"Unknown DenseNet: {hparams.model_name}") from None

    backbone = net_selection(
        spatial_dims=hparams.input_dim,
        in_channels=in_channels,
        out_channels=out_channels,
        dropout_prob=hparams.dropout,
        act=("relu", {"inplace": False}),
    )
    # Replace the pooling in the 2nd and 3rd transition blocks with identities so
    # they no longer downsample. Presumably this keeps more spatial resolution for
    # small inputs (TODO confirm intent).
    backbone.features.transition2.pool = torch.nn.Identity()
    backbone.features.transition3.pool = torch.nn.Identity()
    return backbone


def _build_torchvision_resnet(model_name, out_channels):
    """Load a torchvision ResNet-family model via torch.hub and resize its head.

    A trailing ``-pretrained`` (case-insensitive) requests pretrained weights. It is
    stripped before the hub lookup because it is not part of the hub entrypoint name.
    """
    suffix = '-pretrained'
    pretrained = model_name.lower().endswith(suffix)
    hub_name = model_name[:-len(suffix)] if pretrained else model_name
    backbone = torch.hub.load('pytorch/vision:v0.10.0', hub_name, pretrained=pretrained)
    # Assigning `fc.out_features` would only change an attribute, not the weight
    # shape. Replace the layer so the model really emits `out_channels` outputs.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, out_channels)
    # NOTE(review): torchvision ResNets are 2-D with a 3-channel stem. The
    # `in_channels` / `input_dim` computed by get_backbone are not applied here;
    # confirm the data matches.
    return backbone


def _load_models_genesis_weights(backbone):
    """Load the Models Genesis checkpoint into ``backbone`` (non-strict)."""
    weight_path = os.path.join('data_sl', utils.config.globals["MODELS_GENESIS_PATH"])
    checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']

    # Strip only a *leading* "module." (presumably added by DataParallel when the
    # checkpoint was saved); a plain str.replace would also mangle inner key parts.
    prefix = "module."
    unparalleled_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates missing/unexpected keys, but tensor shape mismatches
    # still raise. NOTE(review): confirm in_channels / n_class match the checkpoint.
    backbone.load_state_dict(unparalleled_state_dict, strict=False)