File size: 3,859 Bytes
02c5426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn

from models import register
from torch.nn import TransformerEncoderLayer, TransformerEncoder, LayerNorm
from torch.nn.init import xavier_uniform_
from einops import rearrange
# from mmcv.runner.base_module import BaseModule, ModuleList
# from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
# from mmcv.cnn.utils.weight_init import trunc_normal_
# from mmcv.cnn.bricks.registry import DROPOUT_LAYERS


@register('transformer_neck')
class TransformerNeck(nn.Module):
    """Transformer-encoder neck over a list of CNN feature maps.

    Each input map is projected to ``d_dim`` channels with a 1x1 conv;
    the first map may be downsampled by a strided conv. All maps are
    flattened into one token sequence (optionally augmented with an
    interpolated 2-D positional embedding and a learnable class token),
    passed through a ``TransformerEncoder``, and the result is split into
    a global descriptor and a spatial map restored to the first map's
    resolution (upsampled again if downsampling was applied).
    """

    def __init__(self,
                 in_dim,
                 d_dim=256,
                 downsample=True,
                 has_pe=True,
                 has_norm=True,
                 class_token=True,
                 num_encoder_layers=3,
                 dim_feedforward=512,
                 drop_rate=0.1
                 ):
        super().__init__()
        # 1x1 projection of every incoming map into the transformer width.
        self.input_proj = nn.Conv2d(in_dim, d_dim, kernel_size=1)
        self.downsample = downsample
        if self.downsample:
            # Strided 3x3 conv halves the first map's spatial resolution.
            self.sampler = nn.Conv2d(d_dim, d_dim, kernel_size=3, stride=2, padding=1)

        layer = TransformerEncoderLayer(
            d_model=d_dim,
            nhead=8,
            dim_feedforward=dim_feedforward,
            dropout=drop_rate,
            activation='gelu',
            batch_first=True,
        )
        # Final LayerNorm only when requested.
        self.trans_encoder = TransformerEncoder(
            layer,
            num_encoder_layers,
            LayerNorm(d_dim) if has_norm else None,
        )

        if self.downsample:
            # PixelShuffle-based x2 upsampling that undoes the strided conv.
            self.uplayer = nn.Sequential(
                nn.Conv2d(d_dim, d_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(d_dim, d_dim, kernel_size=3, padding=1),
            )

        # for p in self.parameters():
        #     if p.dim() > 1:
        #         xavier_uniform_(p)

        self.d_dim = d_dim
        self.class_token = class_token
        self.has_pe = has_pe

        if self.class_token:
            # Learnable [CLS] token prepended to the token sequence.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_dim))

        if self.has_pe:
            # 2-D positional embedding on a fixed 24x24 grid; resized at
            # forward time to each map's actual spatial size.
            self.pos_embed = nn.Parameter(
                torch.randn(1, d_dim, 24, 24) * 0.02
            )
            self.drop_after_pos = nn.Dropout(p=drop_rate)

    def forward(self, x):
        """x: list of feature maps, each of shape (B, in_dim, h, w).

        Returns ``(global_content, func_map)`` where ``global_content``
        is (B, d_dim) and ``func_map`` is (B, d_dim, H', W') at the first
        map's (possibly re-upsampled) resolution.
        """
        feats = [self.input_proj(feat) for feat in x]
        if self.downsample:
            feats[0] = self.sampler(feats[0])
        B, C, H, W = feats[0].shape

        if self.has_pe:
            # NOTE(review): only the first map's width is checked against
            # the embedding grid; every map is interpolated regardless.
            assert W == self.pos_embed.shape[-1]
            feats = [
                feat + nn.functional.interpolate(
                    self.pos_embed,
                    size=[feat.shape[-2], feat.shape[-1]],
                    mode='bicubic',
                    align_corners=True,
                )
                for feat in feats
            ]

        # Flatten each map (B, C, h, w) -> (B, h*w, C) and concatenate
        # along the token axis.
        tokens = torch.cat(
            [feat.flatten(2).transpose(1, 2) for feat in feats], dim=1
        )

        if self.has_pe:
            tokens = self.drop_after_pos(tokens)

        if self.class_token:
            cls = self.cls_token.expand(B, -1, -1)
            tokens = torch.cat((cls, tokens), dim=1)

        encoded = self.trans_encoder(tokens)

        if self.class_token:
            # Class token carries the global summary; the next H*W tokens
            # belong to the first (reference) feature map.
            global_content = encoded[:, 0]
            func_map = encoded[:, 1:(1 + H * W), :]
        else:
            global_content = encoded.mean(dim=1)  # global pool without cls token
            func_map = encoded[:, :(H * W), :]

        # (B, H*W, C) -> (B, C, H, W)
        func_map = func_map.transpose(1, 2).reshape(B, C, H, W)
        if self.downsample:
            func_map = self.uplayer(func_map)
        return global_content, func_map