File size: 1,223 Bytes
9b9b1dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """
    Squeeze-and-excitation style channel attention.

    The input is squeezed to one value per channel by global average pooling.
    That descriptor goes through a 1x1-convolution bottleneck
    (C -> C // r -> C) followed by a sigmoid, which gives a per-channel gate
    in (0, 1). The original input is then rescaled channel-wise by that gate.

    In:  (N, C, H, W)  (channel-first, as expected by nn.Conv2d)
    Out: (N, C, H, W)  the (N, C, 1, 1) gate is broadcast over H and W when it
         is multiplied with the original input.

    Args:
        in_channels: number of input (and output) channels C.
        reduction_ratio: bottleneck reduction factor r. The hidden layer has
            max(1, in_channels // reduction_ratio) channels.
        bias: whether the two 1x1 convolutions use a bias term.
    """

    def __init__(self, in_channels, reduction_ratio=8, bias=True):
        super().__init__()
        # Keep at least one hidden channel. When in_channels < reduction_ratio
        # the floor division would otherwise produce a zero-width bottleneck.
        hidden_channels = max(1, in_channels // reduction_ratio)

        self.squeezing = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            # Reduce: C -> hidden_channels
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.PReLU(),
            # Expand back: hidden_channels -> C
            nn.Conv2d(
                hidden_channels,
                in_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            # Squash to a per-channel gate in (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned attention weight."""
        squeezed_x = self.squeezing(x)  # (N, C, 1, 1)
        excitation = self.excitation(squeezed_x)  # (N, C, 1, 1) gate in (0, 1)
        # Broadcasting the gate over H and W yields the input's full shape.
        return excitation * x  # (N, C, H, W)