import torch
import torch.nn as nn


class SelectiveKernelFeatureFusion(nn.Module):
    """
    Merges outputs of the three different resolutions through self-attention.

    All three inputs are summed -> global average pooling -> downscaling -> the signal is passed through 3 different convs to have three descriptors,
    softmax is applied to each descriptor to get 3 attention activations used to recalibrate the three input feature maps.
    """

    def __init__(self, in_channels, reduction_ratio, bias=False):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # compress channels by reduction_ratio, but keep at least 4
        conv_out_channels = max(in_channels // reduction_ratio, 4)
        self.convolution = nn.Sequential(
            nn.Conv2d(
                in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
            ),
            nn.PReLU(),
        )

        # one 1x1 convolution per branch, expanding the shared descriptor
        # back to in_channels to produce a per-branch attention vector
        self.attention_convs = nn.ModuleList(
            [
                nn.Conv2d(
                    conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
                )
                for _ in range(3)
            ]
        )

        # softmax over the branch dimension (dim=1 after the view in forward)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x[0].shape[0]
        n_features = x[0].shape[1]

        x = torch.cat(
            x, dim=1
        )  # the three resolution outputs are concatenated along the channel dim
        x = x.view(
            batch_size, 3, n_features, x.shape[2], x.shape[3]
        )  # batch_size x 3 x n_features x H x W

        z = torch.sum(x, dim=1)  # batch_size x n_features x H x W
        z = self.avg_pool(z)  # batch_size x n_features x 1 x 1
        z = self.convolution(z)  # batch_size x conv_out_channels x 1 x 1

        attention_activations = [
            atn(z) for atn in self.attention_convs
        ]  # 3 x (batch_size x n_features x 1 x 1)
        attention_activations = torch.cat(
            attention_activations, dim=1
        )  # batch_size x 3*n_features x 1 x 1
        attention_activations = attention_activations.view(
            batch_size, 3, n_features, 1, 1
        )  # batch_size x 3 x n_features x 1 x 1

        attention_activations = self.softmax(
            attention_activations
        )  # batch_size x 3 x n_features x 1 x 1

        return torch.sum(
            x * attention_activations, dim=1
        )  # batch_size x n_features x H x W (the three feature maps are recalibrated and summed)
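

# A minimal smoke test, assuming the module receives a list of three feature
# maps with identical (batch, channels, H, W) shapes; in_channels=64 and
# reduction_ratio=8 below are illustrative choices, not values fixed by this file.
if __name__ == "__main__":
    skff = SelectiveKernelFeatureFusion(in_channels=64, reduction_ratio=8)
    feats = [torch.randn(2, 64, 32, 32) for _ in range(3)]
    fused = skff(feats)
    print(fused.shape)  # torch.Size([2, 64, 32, 32])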