# ------------------------------------------------------------------------------
# CLIP-DINOiser
# author: Monika Wysoczanska, Warsaw University of Technology
# ------------------------------------------------------------------------------
# Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.ops import resize
from typing import List
from torch import Tensor
from mmcv.utils import print_log
from mmseg.utils import get_root_logger
from open_clip import get_tokenizer, create_model_from_pretrained
from models.builder import MODELS
from .vit import VisionTransformer
import torchvision.transforms as T
from .utils.embed import AdaptivePadding
from .utils.prompt_templates import imagenet_templates

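# Standard OpenAI CLIP image normalization statistics (per-channel RGB mean / std).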
OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))


def make_vision_transformer(backbone_cfg):
    model = VisionTransformer(**backbone_cfg)
    model.init_weights()
    return model


@MODELS.register_module()
class MaskClip(nn.Module):
    def __init__(
            self,
            backbone,
            decode_head,
            clip_model,
            class_names
        ):
        super(MaskClip, self).__init__()

        self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
        self.backbone = make_vision_transformer(backbone)
        self.clip_T = OPENAI_NORMALIZE

        self.to_PIL = T.ToPILImage()
        self.patch_size = backbone.get('patch_size')
        self.padding = AdaptivePadding(self.patch_size, self.patch_size)

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from images."""
        x = self.backbone(inputs)
        return x

    def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)

        seg_logits, feats, k = self.decode_head(x, return_feat)

        if return_feat:
            return seg_logits, feats, k
        return seg_logits
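
# Minimal usage sketch for MaskClip (illustrative only: the config dicts and the
# CLIP model name below are placeholder assumptions, not the project's configs):
#
#   model = MaskClip(
#       backbone=dict(patch_size=16, ...),             # kwargs for VisionTransformer
#       decode_head=dict(type='MaskClipHead', visual_projs_path='proj.pth', ...),
#       clip_model='CLIP-ViT-B-16-laion2B-s34B-b88K',  # any model under hf-hub:laion
#       class_names=['cat', 'dog', 'grass'],
#   )
#   probs = model(images)  # images: (B, 3, H, W) in [0, 1]; output at patch resolution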

class MaskClipHead(nn.Module):
    def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1, in_channels=3, norm_cfg=None, channels=0,
                 text_channels=512, attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion', use_templates=False, **kwargs):
        super(MaskClipHead, self).__init__()

        self.text_channels = text_channels
        self.visual_projs_path = visual_projs_path
        self.clip_model = clip_model
        self.class_names = class_names
        self.in_channels = in_channels
        self.in_index = in_index # from base decode head default
        self._init_inputs(in_channels, in_index, None)
        self.channels = channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.use_templates = use_templates

        self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
        self.load_visual_projs()

        self.attn_pooling = attn_pooling
        self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
        self.hf_modelname = f'{model_prefix}/{clip_model}'
        model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
        model.eval()
        self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))

    @torch.no_grad()
    def update_vocab(self, class_names):
        model, _ = create_model_from_pretrained(self.hf_modelname)
        model.eval()
        self.class_embeddings = self._get_class_embeddings(model, class_names)

    @torch.no_grad()
    def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
        """
        Encode label name into a single vector
        """
        if self.use_templates:
            templates = imagenet_templates
        else:
            templates = ['a photo of an {}' if label[0].lower() in 'aeiou' else 'a photo of a {}']

        all_prompts = [self.tokenizer(template.format(label)) for template in templates]
        out = text_model.encode_text(torch.cat(all_prompts))
        out /= out.norm(dim=-1, keepdim=True)
        out = out.mean(dim=0)
        return out

    def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
        aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
        # normalize vector
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        return aug_embeddings.squeeze(1)

    def load_visual_projs(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        loaded = torch.load(self.visual_projs_path, map_location=device)
        attrs = ['proj']
        for attr in attrs:
            current_attr = getattr(self, attr)
            state_dict = loaded[attr]
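            # The saved projection is stored as a linear weight of shape
            # (out_dim, in_dim); the extra singleton dims added below let it be
            # loaded into the 1x1 Conv2d defined in __init__.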
            for key in state_dict:
                if 'weight' in key:
                    state_dict[key] = state_dict[key][:, :, None, None]
            current_attr.load_state_dict(state_dict)
        print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())

    def forward(self, inputs, return_feat=False):
        x = self._transform_inputs(inputs)
        q, k, v, cls_token = None, None, None, None
        if isinstance(x, list) and len(x) == 4:
            x, q, k, v = x
        if isinstance(x, list) and len(x) == 2:
            x, cls_token = x
        if v is not None:
            feat = self.proj(v)
        else:
            feat = self.proj(x)
        output = self.cls_seg(feat)
        if return_feat:
            return output, feat, k

        return output

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        sequences of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def cls_seg(self, feat):
        feat = feat / feat.norm(dim=1, keepdim=True)
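        # 1x1 convolution with the L2-normalized class text embeddings as
        # kernels: a per-pixel dot product, i.e. cosine similarity between each
        # spatial feature and each class embedding.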
        output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
        output = F.softmax(output * 100, dim=1) # softmax of similarities with temp scaling

        return output

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs