import mmpretrain
import torch
from torch import nn
from collections.abc import Iterable
from mmpretrain.models.utils.attention import MultiheadAttention

# Holds model instantiation functions keyed by (dataset_name, model_name) tuples
MODEL_REGISTRY = {}


class ClsModel(nn.Module):
    dataset_name: str
    model_name: str

    def __init__(self, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__()
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.device = device

    def head_features(self):
        pass

    def num_classes(self):
        pass

    def forward(self, x):
        """
        x: [B, 3 (RGB), H, W] image (float) in [0, 1]
        returns: [B, C] class logits
        """
        raise NotImplementedError("Forward not implemented for base class")
class MMPretrainModelWrapper(ClsModel):
    """
    Calls data preprocessing for the model before entering forward
    """

    def __init__(self, model: nn.Module, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(dataset_name, model_name, device)
        self.model = model

    @property
    def final_linear_layer(self):
        return self.model.head.fc

    def head_features(self):
        return self.final_linear_layer.in_features

    def num_classes(self):
        return self.final_linear_layer.out_features

    def head(self, feats):
        return self.model.head((feats,))

    def head_matrices(self):
        return self.final_linear_layer.weight, self.final_linear_layer.bias

    def forward(self, x, return_features=False):
        # The data preprocessor expects the 0-255 range, but we don't cast to
        # proper uint8 because we want to maintain differentiability
        x = x * 255.
        x = self.model.data_preprocessor({"inputs": x})["inputs"]
        if return_features:
            feats = self.model.extract_feat(x)
            preds = self.model.head(feats)
            if isinstance(feats, Iterable):
                feats = feats[-1]
            return preds, feats
        else:
            return self.model(x)
class MMPretrainVisualTransformerWrapper(MMPretrainModelWrapper):
    def __init__(self, model, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(model, dataset_name, model_name, device)
        # Collect every multi-head attention layer so its attention maps can be
        # recorded later in get_attention_maps
        attn_layers = []

        def find_mha(m: nn.Module):
            if isinstance(m, MultiheadAttention):
                attn_layers.append(m)

        model.apply(find_mha)
        self.attn_layers = attn_layers

    @property
    def final_linear_layer(self):
        return self.model.head.layers.head

    def get_attention_maps(self, x):
        clean_attentions = []
        attention_maps = []

        def scaled_dot_prod_attn(query,
                                 key,
                                 value,
                                 attn_mask=None,
                                 dropout_p=0.,
                                 scale=None,
                                 is_causal=False):
            # Mirrors the pure-Python scaled dot-product attention, but records
            # the post-softmax attention weights as a side effect
            scale = scale or query.size(-1)**0.5
            if is_causal and attn_mask is None:
                attn_mask = torch.ones(
                    query.size(-2), key.size(-2), dtype=torch.bool,
                    device=query.device).tril(diagonal=0)
            if attn_mask is not None and attn_mask.dtype == torch.bool:
                attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
                    attn_mask.logical_not(), -float('inf'))
            attn_weight = query @ key.transpose(-2, -1) / scale
            if attn_mask is not None:
                attn_weight += attn_mask
            attn_weight = torch.softmax(attn_weight, dim=-1)
            attention_maps.append(attn_weight)
            attn_weight = torch.dropout(attn_weight, dropout_p, True)
            return attn_weight @ value

        # Temporarily swap in the recording implementation on every attention layer
        for attn_layer in self.attn_layers:
            clean_attentions.append(attn_layer.scaled_dot_product_attention)
            attn_layer.scaled_dot_product_attention = scaled_dot_prod_attn

        # Run a forward pass purely for its side effect of filling attention_maps
        super().forward(x, False)

        # Restore the original attention implementations
        for attn_layer, clean_attention in zip(self.attn_layers, clean_attentions):
            attn_layer.scaled_dot_product_attention = clean_attention
        return attention_maps
def register_mmcls_model(config_name, dataset_name, model_name,
                         wrapper_class=MMPretrainModelWrapper):
    def instantiate_model(device):
        model = mmpretrain.get_model(config_name, pretrained=True, device=device)
        wrapper = wrapper_class(model, dataset_name, model_name, device)
        return wrapper

    MODEL_REGISTRY[(dataset_name, model_name)] = instantiate_model


def register_default_models():
    register_mmcls_model("resnet18_8xb16_cifar10", "cifar10", "resnet18")
    register_mmcls_model("resnet34_8xb16_cifar10", "cifar10", "resnet34")
    register_mmcls_model("resnet18_8xb32_in1k", "imagenet", "resnet18")
    register_mmcls_model("resnet50_8xb16_cifar100", "cifar100", "resnet50")
    register_mmcls_model("resnet50_8xb32_in1k", "imagenet", "resnet50")
    register_mmcls_model("densenet121_3rdparty_in1k", "imagenet", "densenet121")
    register_mmcls_model("deit-small_4xb256_in1k", "imagenet", "deit_small",
                         wrapper_class=MMPretrainVisualTransformerWrapper)
    register_mmcls_model("vit-base-p16_32xb128-mae_in1k", "imagenet", "vit_base",
                         wrapper_class=MMPretrainVisualTransformerWrapper)
def get_model(dataset_name, model_name, device):
    """
    Returns an instance of the model pretrained on the specified dataset, in eval mode
    """
    return MODEL_REGISTRY[(dataset_name, model_name)](device).eval()
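

# Example usage -- a minimal sketch, assuming the mmpretrain checkpoints registered
# above can be downloaded and that a CUDA device is available (swap "cuda" for "cpu"
# otherwise). The 224x224 input size is illustrative.
if __name__ == "__main__":
    register_default_models()

    # Classification logits plus backbone features from a wrapped ResNet-50
    model = get_model("imagenet", "resnet50", "cuda")
    images = torch.rand(1, 3, 224, 224, device="cuda")  # [B, 3, H, W] floats in [0, 1]
    logits, feats = model(images, return_features=True)
    print(logits.shape, feats.shape)

    # Per-layer attention maps are only exposed by the transformer wrappers
    vit = get_model("imagenet", "deit_small", "cuda")
    attn_maps = vit.get_attention_maps(images)
    print(len(attn_maps), attn_maps[0].shape)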