import mmpretrain
import torch
from torch import nn
from collections.abc import Iterable
from mmpretrain.models.utils.attention import MultiheadAttention
# Maps (dataset_name, model_name) tuples to functions that instantiate the corresponding pretrained model
MODEL_REGISTRY = {}
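# For illustration, a registered entry is used like this (a sketch; the keys come from
# register_default_models() further down):
#   build_fn = MODEL_REGISTRY[("cifar10", "resnet18")]
#   model = build_fn("cpu")  # returns a wrapped, pretrained ClsModel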
class ClsModel(nn.Module):
dataset_name: str
model_name: str
def __init__(self, dataset_name: str, model_name: str, device: str) -> None:
super().__init__()
self.dataset_name = dataset_name
self.model_name = model_name
self.device = device
    def head_features(self):
        raise NotImplementedError("head_features not implemented for base class")
    def num_classes(self):
        raise NotImplementedError("num_classes not implemented for base class")
def forward(self, x):
"""
x: [B, 3 (RGB), H, W] image (float) [0,1]
returns: [B, C] class logits
"""
raise NotImplementedError("Forward not implemented for base class")
class TimmPretrainModelWrapper(ClsModel):
"""
    Wraps a timm model and applies its preprocessing transform before the forward pass.
"""
def __init__(self, model: nn.Module, transform, dataset_name: str, model_name: str, device: str) -> None:
super().__init__(dataset_name, model_name, device)
self.model = model
self.transform = transform
@property
    def final_linear_layer(self):
        # timm models expose the classifier as `head`, `head.fc`, or `fc`,
        # depending on the architecture
        if hasattr(self.model, "head"):
            if isinstance(self.model.head, torch.nn.Linear):
                return self.model.head
            return self.model.head.fc
        return self.model.fc
def head_features(self):
return self.final_linear_layer.in_features
def num_classes(self):
return self.final_linear_layer.out_features
def head(self, feats):
return self.model.head((feats,))
def head_matrices(self):
return self.final_linear_layer.weight, self.final_linear_layer.bias
def forward(self, x, return_features=False):
x = self.transform(x)
if return_features:
            feats = self.model.forward_features(x)
            # pre_logits=True returns the pooled features just before the classifier
            pooled = self.model.forward_head(feats, pre_logits=True)
            try:
                preds = self.model.fc(pooled)  # convnets expose the classifier as `fc`
            except AttributeError:
                preds = self.model.head(pooled)  # ViTs expose it as `head`
            return preds, pooled
else:
return self.model(x)
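# Example usage of TimmPretrainModelWrapper (a sketch, not registered anywhere in this file;
# assumes the `timm` package is installed and that `transform` is a tensor-space op such as
# ImageNet normalization, since forward() already receives a [B, 3, H, W] float tensor in [0, 1]):
#   import timm
#   from torchvision.transforms import Normalize
#   backbone = timm.create_model("resnet18", pretrained=True)
#   normalize = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
#   wrapper = TimmPretrainModelWrapper(backbone, normalize, "imagenet", "resnet18", "cpu")
#   logits = wrapper(torch.rand(1, 3, 224, 224))
#   logits, feats = wrapper(torch.rand(1, 3, 224, 224), return_features=True)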
class MMPretrainModelWrapper(ClsModel):
"""
    Wraps an mmpretrain model and applies its data preprocessor before the forward pass.
"""
def __init__(self, model: nn.Module, dataset_name: str, model_name: str, device: str) -> None:
super().__init__(dataset_name, model_name, device)
self.model = model
@property
def final_linear_layer(self):
return self.model.head.fc
def head_features(self):
return self.final_linear_layer.in_features
def num_classes(self):
return self.final_linear_layer.out_features
def head(self, feats):
return self.model.head((feats,))
def head_matrices(self):
return self.final_linear_layer.weight, self.final_linear_layer.bias
def forward(self, x, return_features=False):
        # The data preprocessor expects inputs in the 0-255 range, but we avoid casting
        # to uint8 so the tensor stays differentiable
x = x * 255.
x = self.model.data_preprocessor({"inputs": x})["inputs"]
if return_features:
            # extract_feat returns a tuple of per-stage features; the head consumes
            # the whole tuple, but only the final stage is returned as the feature vector
            feats = self.model.extract_feat(x)
            preds = self.model.head(feats)
            if isinstance(feats, Iterable):
                feats = feats[-1]
return preds, feats
else:
return self.model(x)
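# Example usage of MMPretrainModelWrapper (a sketch; register_mmcls_model() below builds
# these wrappers automatically, so this only shows the expected inputs and outputs):
#   backbone = mmpretrain.get_model("resnet18_8xb16_cifar10", pretrained=True, device="cpu")
#   wrapper = MMPretrainModelWrapper(backbone, "cifar10", "resnet18", "cpu")
#   logits = wrapper(torch.rand(1, 3, 32, 32))                              # [B, num_classes]
#   logits, feats = wrapper(torch.rand(1, 3, 32, 32), return_features=True)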
class MMPretrainVisualTransformerWrapper(MMPretrainModelWrapper):
def __init__(self, model, dataset_name: str, model_name: str, device: str) -> None:
super().__init__(model, dataset_name, model_name, device)
        # model.apply() visits every submodule, so this collects every
        # MultiheadAttention layer for patching in get_attention_maps()
        attn_layers = []
        def find_mha(m: nn.Module):
            if isinstance(m, MultiheadAttention):
                attn_layers.append(m)
        model.apply(find_mha)
        self.attn_layers = attn_layers
@property
def final_linear_layer(self):
return self.model.head.layers.head
    def get_attention_maps(self, x):
        # Temporarily patch each attention layer's scaled_dot_product_attention with a
        # pure-PyTorch implementation that records the softmaxed attention weights
        original_attn_fns = []
        attention_maps = []
        for attn_layer in self.attn_layers:
            original_attn_fns.append(attn_layer.scaled_dot_product_attention)
            def scaled_dot_prod_attn(query,
                                     key,
                                     value,
                                     attn_mask=None,
                                     dropout_p=0.,
                                     scale=None,
                                     is_causal=False):
                scale = scale or query.size(-1)**0.5
                if is_causal and attn_mask is not None:
                    attn_mask = torch.ones(
                        query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
                if attn_mask is not None and attn_mask.dtype == torch.bool:
                    # Convert the boolean mask into an additive float mask (-inf where masked out)
                    attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
                        attn_mask.logical_not(), -float('inf'))
                attn_weight = query @ key.transpose(-2, -1) / scale
                if attn_mask is not None:
                    attn_weight += attn_mask
                attn_weight = torch.softmax(attn_weight, dim=-1)
                attention_maps.append(attn_weight)
                attn_weight = torch.dropout(attn_weight, dropout_p, True)
                return attn_weight @ value
            attn_layer.scaled_dot_product_attention = scaled_dot_prod_attn
        # Run a forward pass to populate attention_maps, then restore the original attention functions
        super().forward(x, False)
        for attn_layer, original_attn_fn in zip(self.attn_layers, original_attn_fns):
            attn_layer.scaled_dot_product_attention = original_attn_fn
        return attention_maps
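# Example (a sketch, assuming register_default_models() below has been called and an
# ImageNet-sized input):
#   vit = get_model("imagenet", "vit_base", "cpu")
#   maps = vit.get_attention_maps(torch.rand(1, 3, 224, 224))
#   # one tensor per patched attention layer, each roughly [batch, num_heads, num_tokens, num_tokens]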
def register_mmcls_model(config_name, dataset_name, model_name,
wrapper_class=MMPretrainModelWrapper):
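    """
    Registers a lazy constructor under (dataset_name, model_name) in MODEL_REGISTRY.
    The constructor loads the pretrained mmpretrain model for `config_name` on the
    requested device and wraps it with `wrapper_class`.
    """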
def instantiate_model(device):
model = mmpretrain.get_model(config_name, pretrained=True, device=device)
wrapper = wrapper_class(model, dataset_name, model_name, device)
return wrapper
MODEL_REGISTRY[(dataset_name, model_name)] = instantiate_model
def register_default_models():
register_mmcls_model("resnet18_8xb16_cifar10", "cifar10", "resnet18")
register_mmcls_model("resnet34_8xb16_cifar10", "cifar10", "resnet34")
register_mmcls_model("resnet18_8xb32_in1k", "imagenet", "resnet18")
register_mmcls_model("resnet50_8xb16_cifar100", "cifar100", "resnet50")
register_mmcls_model("resnet50_8xb32_in1k", "imagenet", "resnet50")
register_mmcls_model("densenet121_3rdparty_in1k", "imagenet", "densenet121")
register_mmcls_model("deit-small_4xb256_in1k", "imagenet", "deit_small",
wrapper_class=MMPretrainVisualTransformerWrapper)
register_mmcls_model("vit-base-p16_32xb128-mae_in1k", "imagenet", "vit_base",
wrapper_class=MMPretrainVisualTransformerWrapper)
def get_model(dataset_name, model_name, device):
"""
    Returns an eval-mode instance of `model_name` pretrained on `dataset_name`, loaded onto `device`.
    The model must have been registered first (e.g. via register_default_models()).
"""
return MODEL_REGISTRY[(dataset_name, model_name)](device).eval()
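# Example end-to-end usage (a sketch):
#   register_default_models()
#   model = get_model("cifar10", "resnet18", "cpu")
#   logits = model(torch.rand(1, 3, 32, 32))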