import mmpretrain
import torch
from torch import nn
from collections.abc import Iterable
from mmpretrain.models.utils.attention import MultiheadAttention
# This holds model instantiation functions by (dataset_name, model_name) tuple keys
MODEL_REGISTRY = {}

class ClsModel(nn.Module):
    dataset_name: str
    model_name: str

    def __init__(self, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__()
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.device = device

    def head_features(self):
        """Number of input features to the final classification layer."""
        raise NotImplementedError("head_features not implemented for base class")

    def num_classes(self):
        """Number of output classes of the final classification layer."""
        raise NotImplementedError("num_classes not implemented for base class")

    def forward(self, x):
        """
        x: [B, 3 (RGB), H, W] image (float) [0,1]

        returns: [B, C] class logits
        """

        raise NotImplementedError("Forward not implemented for base class")

class TimmPretrainModelWrapper(ClsModel):
    """
    Wraps a timm model: applies the preprocessing transform before the forward pass.
    """
    def __init__(self, model: nn.Module, transform, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(dataset_name, model_name, device)
        self.model = model
        self.transform = transform

    @property
    def final_linear_layer(self):
        # timm models expose their classifier either as `head` (possibly a
        # module whose linear layer is `head.fc`) or directly as `fc`.
        if hasattr(self.model, "head"):
            if isinstance(self.model.head, torch.nn.Linear):
                return self.model.head
            return self.model.head.fc
        return self.model.fc

    def head_features(self):
        return self.final_linear_layer.in_features
    
    def num_classes(self):
        return self.final_linear_layer.out_features
    
    def head(self, feats):
        return self.model.head((feats,))
    
    def head_matrices(self):
        return self.final_linear_layer.weight, self.final_linear_layer.bias

    def forward(self, x, return_features=False):
        x = self.transform(x)
        if return_features:
            feats = self.model.forward_features(x)
            # pre_logits=True returns the pooled features just before the
            # classification layer
            pooled = self.model.forward_head(feats, pre_logits=True)
            try:
                preds = self.model.fc(pooled)  # convnets
            except AttributeError:
                preds = self.model.head(pooled)  # vit
            return preds, pooled
        else:
            return self.model(x)
        
class MMPretrainModelWrapper(ClsModel):
    """
    Wraps an mmpretrain model: runs the model's data preprocessor before the forward pass.
    """
    def __init__(self, model: nn.Module, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(dataset_name, model_name, device)
        self.model = model

    @property
    def final_linear_layer(self):
        return self.model.head.fc

    def head_features(self):
        return self.final_linear_layer.in_features
    
    def num_classes(self):
        return self.final_linear_layer.out_features
    
    def head(self, feats):
        return self.model.head((feats,))
    
    def head_matrices(self):
        return self.final_linear_layer.weight, self.final_linear_layer.bias

    def forward(self, x, return_features=False):
        # Data preprocessor expects 0-255 range, but we don't want to cast to proper
        # uint8 because we want to maintain differentiability
        x = x * 255.
        x = self.model.data_preprocessor({"inputs": x})["inputs"]

        if return_features:
            feats = self.model.extract_feat(x)
            
            preds = self.model.head(feats)
            if isinstance(feats, Iterable):
                feats = feats[-1]
                
            return preds, feats
        else:
            return self.model(x)
        
class MMPretrainVisualTransformerWrapper(MMPretrainModelWrapper):
    def __init__(self, model, dataset_name: str, model_name: str, device: str) -> None:
        super().__init__(model, dataset_name, model_name, device)

        attn_layers = []

        def find_mha(m: nn.Module):
            if isinstance(m, MultiheadAttention):
                attn_layers.append(m)

        model.apply(find_mha)

        self.attn_layers = attn_layers

    @property
    def final_linear_layer(self):
        return self.model.head.layers.head
    
    def get_attention_maps(self, x):
        """
        Runs a forward pass on x and returns the per-layer attention maps.

        Temporarily replaces each attention layer's scaled_dot_product_attention
        with a version that records the softmax attention weights, then restores
        the original implementation.
        """
        original_attns = []

        attention_maps = []

        for attn_layer in self.attn_layers:
            original_attns.append(attn_layer.scaled_dot_product_attention)

            def scaled_dot_prod_attn(query,
                                     key,
                                     value,
                                     attn_mask=None,
                                     dropout_p=0.,
                                     scale=None,
                                     is_causal=False):
                # Note: `scale` is used here as a divisor (sqrt(d) by default),
                # unlike torch's native scaled_dot_product_attention, where it
                # is a multiplier.
                scale = scale or query.size(-1)**0.5
                if is_causal and attn_mask is None:
                    attn_mask = torch.ones(
                        query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
                if attn_mask is not None and attn_mask.dtype == torch.bool:
                    # Convert a boolean keep-mask into an additive float mask
                    attn_mask = torch.zeros_like(
                        attn_mask, dtype=query.dtype).masked_fill(
                            attn_mask.logical_not(), float('-inf'))

                attn_weight = query @ key.transpose(-2, -1) / scale
                if attn_mask is not None:
                    attn_weight += attn_mask
                attn_weight = torch.softmax(attn_weight, dim=-1)

                attention_maps.append(attn_weight)

                attn_weight = torch.dropout(attn_weight, dropout_p, True)
                return attn_weight @ value

            attn_layer.scaled_dot_product_attention = scaled_dot_prod_attn

        super().forward(x, False)

        # Restore the original attention implementations
        for attn_layer, original_attn in zip(self.attn_layers, original_attns):
            attn_layer.scaled_dot_product_attention = original_attn

        return attention_maps
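
# A minimal, self-contained sketch of the attention math recorded above (same
# computation, no monkey-patching), useful for sanity-checking captured maps.
# The tensor shapes below are illustrative assumptions, not fixed by the model.
def _demo_attention_map():
    q = torch.randn(1, 2, 4, 8)  # [batch, heads, tokens, head_dim]
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    attn_weight = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1)**0.5, dim=-1)
    # Each row of an attention map is a probability distribution over keys
    assert torch.allclose(attn_weight.sum(dim=-1), torch.ones(1, 2, 4), atol=1e-5)
    return attn_weight @ v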
    
def register_mmcls_model(config_name, dataset_name, model_name, 
                         wrapper_class=MMPretrainModelWrapper):
    def instantiate_model(device):
        model = mmpretrain.get_model(config_name, pretrained=True, device=device)
        wrapper = wrapper_class(model, dataset_name, model_name, device)
        return wrapper
    
    MODEL_REGISTRY[(dataset_name, model_name)] = instantiate_model

def register_default_models():
    register_mmcls_model("resnet18_8xb16_cifar10", "cifar10", "resnet18")
    register_mmcls_model("resnet34_8xb16_cifar10", "cifar10", "resnet34")
    register_mmcls_model("resnet18_8xb32_in1k", "imagenet", "resnet18")
    register_mmcls_model("resnet50_8xb16_cifar100", "cifar100", "resnet50")
    register_mmcls_model("resnet50_8xb32_in1k", "imagenet", "resnet50")
    register_mmcls_model("densenet121_3rdparty_in1k", "imagenet", "densenet121")

    register_mmcls_model("deit-small_4xb256_in1k", "imagenet", "deit_small",
                          wrapper_class=MMPretrainVisualTransformerWrapper)
    
    register_mmcls_model("vit-base-p16_32xb128-mae_in1k", "imagenet", "vit_base",
                          wrapper_class=MMPretrainVisualTransformerWrapper)

def get_model(dataset_name, model_name, device):
    """
    Returns an instance of the model pretrained on the specified dataset,
    in eval mode.
    """

    return MODEL_REGISTRY[(dataset_name, model_name)](device).eval()
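
# A minimal, self-contained sketch of the registry pattern used above:
# instantiation functions keyed by (dataset_name, model_name) and called
# lazily, so nothing is downloaded at registration time. The names below are
# illustrative, not real registry entries.
def _demo_registry_pattern():
    registry = {}

    def register(key, factory):
        registry[key] = factory

    register(("demo_dataset", "demo_model"), lambda device: ("demo_model", device))
    # Instantiation happens only on lookup, mirroring get_model() above
    return registry[("demo_dataset", "demo_model")]("cpu")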