# Source: E-RADIO / radio_model.py (uploaded by gheinrich, commit f6d64da, 4.25 kB)
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Optional
import torch
from torch import nn
from timm.models import create_model, VisionTransformer
from .enable_cpe_support import enable_cpe
from .input_conditioner import InputConditioner
# Register extra models
from . import extra_timm_models
class RADIOModel(nn.Module):
    """Wrap a backbone so one forward pass yields summary and/or spatial features.

    The raw input is first passed through ``input_conditioner``. The backbone's
    ``forward_features`` output is then split into a summary part and a
    spatial-feature part, and whichever part(s) were requested are returned.

    Args:
        model: Backbone exposing ``forward_features``.
        input_conditioner: Module applied to the input before the backbone.
        return_summary: Whether ``forward`` should return the summary features.
        return_spatial_features: Whether ``forward`` should return the spatial
            features.
        summary_idxs: Optional index tensor used (in the patch-generator path)
            to pick which CLS tokens make up the summary. It is registered as a
            buffer when provided.
    """

    def __init__(
        self,
        model: nn.Module,
        input_conditioner: InputConditioner,
        return_summary: bool,
        return_spatial_features: bool,
        summary_idxs: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.input_conditioner = input_conditioner
        self.return_summary = return_summary
        self.return_spatial_features = return_spatial_features
        # In the patch-generator path of ``forward``, a non-negative value picks
        # exactly one CLS token as the summary; -1 turns that selection off.
        self.summary_select_idx = -1

        if summary_idxs is None:
            self.summary_idxs = None
        else:
            # A buffer follows the module across ``.to(...)`` and is saved in
            # its state_dict.
            self.register_buffer('summary_idxs', summary_idxs)

    @property
    def return_both(self):
        """Whether ``forward`` hands back both the summary and the spatial features."""
        return self.return_summary and self.return_spatial_features

    @property
    def num_summary_tokens(self):
        """Count of leading ViT tokens that ``forward`` drops to form spatial features.

        The value is ``patch_generator.num_skip`` when the backbone has a patch
        generator. Otherwise it is 0 for average pooling and 1 for
        single-token pooling.
        """
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is None:
            return 0 if self.model.global_pool == 'avg' else 1
        return patch_gen.num_skip

    def _split_vit_tokens(self, y: torch.Tensor):
        """Split a ViT ``forward_features`` output into ``(summary, spatial)``.

        Assumes ``y`` is laid out as (batch, tokens, dim). This matches the
        indexing below, but TODO: confirm.
        """
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is None:
            if self.model.global_pool == "avg":
                # Mean over the tokens that follow the prefix tokens. The
                # unsliced sequence is returned as the spatial features.
                return y[:, self.model.num_prefix_tokens :].mean(dim=1), y
            # Token 0 is the summary; every later token is spatial.
            return y[:, 0], y[:, 1:]

        cls_tokens = y[:, : patch_gen.num_cls_tokens]
        if self.summary_select_idx >= 0:
            summary = cls_tokens[:, self.summary_select_idx]
        elif self.summary_idxs is not None:
            summary = cls_tokens[:, self.summary_idxs].flatten(1)
        else:
            summary = cls_tokens.flatten(1)
        return summary, y[:, patch_gen.num_skip :]

    def forward(self, x: torch.Tensor):
        """Condition ``x``, run the backbone, and return the requested feature(s).

        Returns:
            ``(summary, spatial)`` if both are requested. Otherwise ``summary``
            if only the summary is requested, else the spatial features.

        Raises:
            ValueError: if the backbone output is not a list/tuple and the
                backbone is not a timm ``VisionTransformer``.
        """
        conditioned = self.input_conditioner(x)
        features = self.model.forward_features(conditioned)

        if isinstance(features, (list, tuple)):
            # The backbone already separated summary and spatial features.
            summary, all_feat = features
        elif isinstance(self.model, VisionTransformer):
            summary, all_feat = self._split_vit_tokens(features)
        else:
            raise ValueError("Unsupported model type")

        if self.return_both:
            return summary, all_feat
        if self.return_summary:
            return summary
        return all_feat
def create_model_from_args(args) -> nn.Module:
    """Build a timm backbone from an argument namespace, optionally enabling CPE.

    ``args`` is presumably an ``argparse.Namespace`` (TODO confirm against
    callers). This function reads the following attributes from it:

    - Always: ``cls_token_per_teacher``, ``cpe_max_size``, ``in_chans``,
      ``input_size``, ``model_kwargs``, ``model``, ``pretrained``,
      ``num_classes``, ``drop``, ``drop_path``, ``drop_block``, ``gp``,
      ``bn_momentum``, ``bn_eps``, ``torchscript`` and ``initial_checkpoint``.
    - Only when ``cpe_max_size`` is not None: ``register_multiple``.
    - Only when both ``cpe_max_size`` and ``cls_token_per_teacher`` are set:
      ``teachers``.

    ``args`` (including ``args.model_kwargs``) is not modified.

    Returns:
        The model built by ``timm.models.create_model``. When
        ``args.cpe_max_size`` is not None, ``enable_cpe`` has been applied to
        it (presumably in place, since its return value is not used).

    Raises:
        AssertionError: if ``args.cls_token_per_teacher`` is set while
            ``args.cpe_max_size`` is None. This is checked before the model is
            built.
    """
    # Validate before paying for model construction / checkpoint loading. Use
    # an explicit raise rather than ``assert`` so ``python -O`` cannot strip
    # the check.
    if args.cls_token_per_teacher and args.cpe_max_size is None:
        raise AssertionError("CPE must be enabled for multiple CLS tokens!")

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        # Presumably input_size is (C, H, W); only the leading entry is used.
        in_chans = args.input_size[0]

    # Pop from a copy: popping straight from ``args.model_kwargs`` would make a
    # second call with the same args silently lose a user-supplied
    # ``weight_init`` and fall back to "skip".
    model_kwargs = dict(args.model_kwargs or {})
    # Skip weight initialization unless it's explicitly requested.
    weight_init = model_kwargs.pop("weight_init", "skip")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        weight_init=weight_init,
        **model_kwargs,
    )

    if args.cpe_max_size is not None:
        enable_cpe(
            model,
            args.cpe_max_size,
            # One CLS token per teacher when requested, otherwise a single one.
            num_cls_tokens=len(args.teachers) if args.cls_token_per_teacher else 1,
            register_multiple=args.register_multiple,
        )
    return model