# ---------------------------------------------------------------------------------------------------
# CLIP-DINOiser
# authors: Monika Wysoczanska, Warsaw University of Technology & Oriane Simeoni, valeo.ai
# ---------------------------------------------------------------------------------------------------
import torch.nn as nn
from models.builder import MODELS
from models.builder import build_model
import torch
import torchvision.transforms as T
from omegaconf import OmegaConf
import torch.nn.functional as F

NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))


@MODELS.register_module()
class CLIP_DINOiser(nn.Module):
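    """Open-vocabulary segmenter built around a MaskCLIP backbone.

    Two lightweight convolutional heads are trained on top of dense CLIP
    features: `obj_proj`, whose pairwise feature correlations drive the
    weighted pooling in `compute_weighted_pool`, and `bkg_decoder`, a 1x1
    convolution producing a FOUND-style foreground/background map used to
    clean up the final prediction when `apply_found` is set.
    """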
    def __init__(self, clip_backbone, class_names, mask_th=None, found_th=0.5, certainty_th=0.9, apply_found=False,
                 in_dim=256, conv_kernel=3, feats_idx=-3):

        super(CLIP_DINOiser, self).__init__()
        self.mask_th = mask_th
        self.apply_found = apply_found
        self.found_th = found_th
        self.certainty_th = certainty_th
        self.sigmoid = nn.Sigmoid()
        maskclip_cfg = OmegaConf.load(f"configs/{clip_backbone}.yaml")
        self.clip_backbone = build_model(maskclip_cfg["model"], class_names=class_names)
        self.vit_patch_size = self.clip_backbone.patch_size
        self.feats_idx = feats_idx
        self.in_dim = [in_dim]
        in_size = 768 if self.feats_idx != 'final' else 512
        self.bkg_decoder = nn.Conv2d(in_size, 1, (1, 1))
        self.obj_proj = nn.Conv2d(in_size, in_dim, (conv_kernel, conv_kernel),
                                  padding=conv_kernel // 2, padding_mode='replicate')

        # Register a forward hook on an intermediate CLIP layer (post-ln2); its
        # output is the feature the heads operate on unless feats_idx == 'final'.
        if feats_idx != 'final':
            train_feats = {}
            def get_activation(name):
                def hook(model, input, output):
                    train_feats[name] = output.detach()
                return hook
            self.clip_backbone.backbone.layers[feats_idx].ln2.register_forward_hook(get_activation('clip_inter'))
            self.train_feats = train_feats


    def forward_pass(self, x):
        clip_feats = self.get_clip_map(x)[0]
        B, c_dim, h, w = clip_feats.shape
        _, _, H, W = x.shape
        if self.feats_idx != 'final':
            clip_feats = self.train_feats['clip_inter']
            c_dim = clip_feats.shape[-1]
            clip_feats = clip_feats[:, 1:].permute(0, 2, 1).reshape(B, c_dim, h, w)  # drop CLS token, restore 2D map

        proj_feats = self.obj_proj(clip_feats).reshape(B, self.in_dim[-1], -1)
        proj_feats = proj_feats / proj_feats.norm(dim=1, keepdim=True)

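        # Pairwise similarities between L2-normalized projected patch features,
        # reshaped to (B, HW, H, W) so each slice is a per-patch correlation map.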
        corrs = torch.matmul(proj_feats.permute(0, 2, 1), proj_feats).reshape(B, h * w, h, w)
        output = clip_feats / clip_feats.norm(dim=1, keepdim=True)
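        # 1x1 conv head over L2-normalized CLIP features: a single-channel map
        # used downstream as a FOUND-style foreground/background cue.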
        output = self.bkg_decoder(output)

        return output, corrs

    def forward(self, x):
        preds, corrs = self.forward_pass(x)
        output, _, _ = self.get_clip_map(x)
        B, C, hf, wf = output.shape
        preds = F.interpolate(preds, (hf, wf), mode="bilinear", align_corners=False)

        # Refine CLIP features by correlation-weighted pooling; correlations
        # below `mask_th` (if set) are zeroed out beforehand.
        if self.mask_th:
            corrs[corrs < self.mask_th] = 0.0
        output = self.compute_weighted_pool(output, corrs)
        output = output.reshape(B, C, hf, wf)
        output = self.clip_backbone.decode_head.cls_seg(output)

        if self.apply_found:
            # Compute FOUND --------------------------------------------------
            soft_found = self.sigmoid(preds.detach())
            r_soft_found = soft_found.reshape(-1)
            nb_cls = output.shape[1]
            r_hard_found = (r_soft_found > self.found_th).float()

            # TODO: make it work for Batch Size != 1
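            # Pixels with low class confidence that the saliency head also marks
            # as background are reassigned to the background class (channel 0).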
            uncertain = (output.max(dim=1)[0] < self.certainty_th).reshape(-1)
            output.reshape(1, nb_cls, -1)[:, 0, uncertain & (~r_hard_found.bool())] = 1.0  # background class

        return output

    def predict(self, x):
        return self(x)

    @torch.no_grad()
    def get_clip_map(self, img):
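        # The backbone returns (maskclip_map, feat, k); reorder so callers get
        # the dense features first.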
        maskclip_map, feat, k = self.clip_backbone(img, return_feat=True)

        return feat, k, maskclip_map

    @torch.no_grad()
    def compute_weighted_pool(self, clipmap, corrs):
        # Upsample the CLIP map to the correlation grid resolution if they differ
        B = clipmap.shape[0]
        h_m, w_m = clipmap.shape[-2:]
        h_w, w_w = corrs.shape[-2:]

        if (h_m != h_w) or (w_m != w_w):
            clipmap = F.interpolate(clipmap, (h_w, w_w), mode="bilinear", align_corners=False)
            h_m, w_m = h_w, w_w

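        # Clamp negative correlations, then average all patch features weighted
        # by their (non-negative) correlation with each query location.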
        corrs[corrs < 0.0] = 0.0   # B HW H W
        clipmap_refined = torch.einsum("bnij, bcij -> bcn", corrs, clipmap)  # B C HW
        norm_factor = corrs.flatten(-2, -1).sum(dim=-1)[:, None] # B 1 HW
        clipmap_refined = clipmap_refined / (norm_factor + 1e-6)

        # RESHAPE back to 2d
        clipmap_refined = clipmap_refined.reshape(B, -1, h_m, w_m)

        return clipmap_refined
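

# ---------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# The backbone config name "maskclip" and the class names below are assumptions;
# any config under configs/ that builds a compatible MaskCLIP model would work.
# ---------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    model = CLIP_DINOiser(
        clip_backbone="maskclip",                  # assumed: resolves to configs/maskclip.yaml
        class_names=["cat", "dog", "grass"],       # hypothetical open-vocabulary classes
        apply_found=True,
    ).eval()

    dummy = NORMALIZE(torch.rand(1, 3, 448, 448))  # input resolution is illustrative
    with torch.no_grad():
        seg_logits = model.predict(dummy)          # dense class logits at patch resolution
    print(seg_logits.shape)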