|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from timm.models import create_model, VisionTransformer |
|
|
|
from .enable_cpe_support import enable_cpe |
|
from .input_conditioner import InputConditioner |
|
|
|
from . import extra_timm_models |
|
|
|
|
|
class RADIOModel(nn.Module):
    """Backbone wrapper that exposes summary and/or spatial features.

    ``forward`` runs the input through ``input_conditioner``, calls
    ``model.forward_features`` on the result, splits the backbone output into
    a summary and a set of spatial features, and returns whichever of the two
    the ``return_summary`` / ``return_spatial_features`` flags select.
    """

    def __init__(
        self,
        model: nn.Module,
        input_conditioner: InputConditioner,
        return_summary: bool,
        return_spatial_features: bool,
        summary_idxs: Optional[torch.Tensor] = None,
    ):
        """Store the wrapped modules and the output-selection settings.

        Args:
            model: Backbone; must provide ``forward_features``.
            input_conditioner: Module applied to the input before the backbone.
            return_summary: Whether ``forward`` returns the summary.
            return_spatial_features: Whether ``forward`` returns the spatial
                features.
            summary_idxs: Optional index tensor used to pick a subset of the
                CLS tokens as the summary. Registered as a buffer when given.
        """
        super().__init__()

        self.model = model
        self.input_conditioner = input_conditioner
        self.return_summary = return_summary
        self.return_spatial_features = return_spatial_features
        # A value >= 0 makes ``forward`` use that single CLS token as summary.
        self.summary_select_idx = -1

        if summary_idxs is None:
            self.summary_idxs = None
        else:
            self.register_buffer('summary_idxs', summary_idxs)

    @property
    def return_both(self):
        """Truthy when both the summary and the spatial features are requested."""
        return self.return_summary and self.return_spatial_features

    @property
    def num_summary_tokens(self):
        """Number of summary tokens the backbone emits.

        Uses ``patch_generator.num_skip`` when the backbone has a patch
        generator; otherwise 0 for ``global_pool == 'avg'`` and 1 for anything
        else.
        """
        generator = getattr(self.model, "patch_generator", None)
        if generator is not None:
            return generator.num_skip
        return 0 if self.model.global_pool == 'avg' else 1

    def _split_vit_tokens(self, tokens):
        """Split a VisionTransformer token tensor into (summary, features).

        Assumes ``tokens`` is laid out as (batch, tokens, dim) -- TODO confirm
        against the backbone's ``forward_features``.
        """
        generator = getattr(self.model, "patch_generator", None)

        if generator is None:
            if self.model.global_pool == "avg":
                prefix = self.model.num_prefix_tokens
                # NOTE(review): the features returned here are the full token
                # sequence, prefix tokens included -- confirm this is intended.
                return tokens[:, prefix:].mean(dim=1), tokens
            # Default layout: first token is the summary, the rest are features.
            return tokens[:, 0], tokens[:, 1:]

        cls_tokens = tokens[:, : generator.num_cls_tokens]
        if self.summary_select_idx >= 0:
            summary = cls_tokens[:, self.summary_select_idx]
        elif self.summary_idxs is not None:
            summary = cls_tokens[:, self.summary_idxs].flatten(1)
        else:
            summary = cls_tokens.flatten(1)
        return summary, tokens[:, generator.num_skip :]

    def forward(self, x: torch.Tensor):
        """Condition ``x``, run the backbone and return the selected outputs.

        Returns:
            ``(summary, features)`` when both are requested, only the summary
            when just ``return_summary`` is set, otherwise only the features.

        Raises:
            ValueError: If the backbone output is neither a list/tuple nor
                produced by a ``VisionTransformer``.
        """
        tokens = self.model.forward_features(self.input_conditioner(x))

        if isinstance(tokens, (list, tuple)):
            # Backbone already produced the (summary, features) pair itself.
            summary, all_feat = tokens
        elif isinstance(self.model, VisionTransformer):
            summary, all_feat = self._split_vit_tokens(tokens)
        else:
            raise ValueError("Unsupported model type")

        if self.return_both:
            return summary, all_feat
        if self.return_summary:
            return summary
        return all_feat
|
|
|
|
|
def create_model_from_args(args) -> nn.Module:
    """Build the backbone described by ``args`` via timm's ``create_model``.

    Args:
        args: Namespace-like object. Fields read here: ``in_chans``,
            ``input_size``, ``model_kwargs``, ``model``, ``pretrained``,
            ``num_classes``, ``drop``, ``drop_path``, ``drop_block``, ``gp``,
            ``bn_momentum``, ``bn_eps``, ``torchscript``,
            ``initial_checkpoint``, ``cls_token_per_teacher``,
            ``cpe_max_size`` and, when CPE is enabled, ``teachers`` and
            ``register_multiple``. ``args.model_kwargs`` is not modified.

    Returns:
        The created model, with CPE support enabled when
        ``args.cpe_max_size`` is not ``None``.

    Raises:
        AssertionError: If ``args.cls_token_per_teacher`` is set while
            ``args.cpe_max_size`` is ``None``.
    """
    # Channel count: explicit value first, then the leading entry of
    # ``input_size``, then the default of 3.
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3

    # Pop from a copy so the caller's dict is left intact; otherwise a second
    # call with the same ``args`` would silently lose a custom ``weight_init``.
    model_kwargs = dict(args.model_kwargs)
    # Weight initialisation is skipped unless explicitly requested.
    weight_init = model_kwargs.pop("weight_init", "skip")

    # Validate before building the model so a bad configuration fails fast.
    # An explicit raise (rather than ``assert``) keeps the check under
    # ``python -O`` while preserving the exception type.
    if args.cls_token_per_teacher and args.cpe_max_size is None:
        raise AssertionError("CPE must be enabled for multiple CLS tokens!")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        weight_init=weight_init,
        **model_kwargs,
    )

    if args.cpe_max_size is not None:
        # One CLS token per teacher when requested, otherwise a single one.
        num_cls_tokens = len(args.teachers) if args.cls_token_per_teacher else 1
        enable_cpe(
            model,
            args.cpe_max_size,
            num_cls_tokens=num_cls_tokens,
            register_multiple=args.register_multiple,
        )

    return model
|
|