# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import torch
from mmengine.model import constant_init, kaiming_init
from mmengine.registry import MODELS
from torch import nn


def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
    """Apply ``constant_init(..., val=0)`` to the last layer of ``m``.

    Args:
        m (nn.Module | nn.Sequential): If ``m`` is an ``nn.Sequential``, only
            its final sub-module is initialized; otherwise ``m`` itself is.
    """
    target = m[-1] if isinstance(m, nn.Sequential) else m
    constant_init(target, val=0)


@MODELS.register_module()
class ContextBlock(nn.Module):
    """ContextBlock module in GCNet.

    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
    (https://arxiv.org/abs/1904.11492) for details.

    The block pools the input feature map into a per-channel context vector
    of shape ``[N, C, 1, 1]``, passes it through one bottleneck transform per
    requested fusion type, and fuses the result back into the input by
    broadcasting over the spatial dimensions.

    Args:
        in_channels (int): Channels of the input feature map.
        ratio (float): Ratio of channels of transform bottleneck. The
            bottleneck width is ``int(in_channels * ratio)``.
        pooling_type (str): Pooling method for context modeling.
            Options are 'att' and 'avg', stand for attention pooling and
            average pooling respectively. Default: 'att'.
        fusion_types (Sequence[str]): Fusion method for feature fusion,
            Options are 'channel_add', 'channel_mul', stand for channelwise
            addition and multiplication respectively. Must be a non-empty
            list or tuple. Default: ('channel_add',)
    """

    _abbr_ = 'context_block'

    def __init__(self,
                 in_channels: int,
                 ratio: float,
                 pooling_type: str = 'att',
                 fusion_types: Union[list, tuple] = ('channel_add', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all(f in valid_fusion_types for f in fusion_types)
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            # 1x1 conv produces one attention logit per spatial position;
            # the softmax is taken over the flattened H * W axis (dim=2).
            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = self._build_transform()
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = self._build_transform()
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def _build_transform(self) -> nn.Sequential:
        # Bottleneck transform: 1x1 conv -> LayerNorm -> ReLU -> 1x1 conv.
        return nn.Sequential(
            nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(self.planes, self.in_channels, kernel_size=1))

    def reset_parameters(self) -> None:
        """(Re-)initialize the attention conv and the fusion transforms.

        The attention conv (only present for ``pooling_type == 'att'``) gets
        ``kaiming_init`` with ``mode='fan_in'`` and is flagged as ``inited``.
        The last layer of every existing fusion transform is passed to
        :func:`last_zero_init`.
        """
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` over its spatial dimensions into a context vector.

        Args:
            x (torch.Tensor): Input of shape ``[N, C, H, W]``.

        Returns:
            torch.Tensor: Context of shape ``[N, C, 1, 1]``. With attention
            pooling this is the softmax-weighted sum of the features over all
            positions; otherwise it is the adaptive average pool of ``x``.
        """
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            # ``reshape`` instead of ``view`` so that non-contiguous inputs
            # are accepted; for contiguous inputs no copy is made.
            # [N, C, H * W]
            input_x = x.reshape(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.reshape(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.reshape(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the pooled global context back into ``x``.

        Args:
            x (torch.Tensor): Input of shape ``[N, C, H, W]``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``. If present, the
            multiplicative term (sigmoid-gated) is applied first, then the
            additive term.
        """
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out