File size: 379 Bytes
9b9b1dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn


class ChannelCompression(nn.Module):
    """
    Squeeze dimension 1 of the input down to two slices: its max and its mean.

    The reduction runs over dimension 1 only; every other dimension is left
    as it is. In the result, index 0 along dimension 1 holds the maximum and
    index 1 holds the mean.

    NOTE(review): dimension 1 is presumably the channel axis of a
    batch-first tensor such as (N, C, H, W) -> (N, 2, H, W); confirm
    against the callers.
    """

    def forward(self, x):
        # keepdim=True retains the reduced axis with size 1, so both pooled
        # tensors can be joined along that same axis without reshaping.
        pooled_max = torch.max(x, dim=1, keepdim=True).values
        pooled_mean = torch.mean(x, dim=1, keepdim=True)

        # Order matters to downstream consumers: max first, mean second.
        return torch.cat([pooled_max, pooled_mean], dim=1)