# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta
from typing import Dict, Optional

import torch
import torch.nn as nn
from mmengine.model import constant_init, normal_init
from mmengine.registry import MODELS

from .conv_module import ConvModule


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971
    Code reference: https://github.com/AlexHex7/Non-local_pytorch

    The block computes a pairwise weight between positions of the input
    (via one of four ``mode`` functions), uses it to aggregate the ``g``
    embedding, projects the result back to ``in_channels`` with
    ``conv_out`` and adds it to the input as a residual.

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `gaussian`, `concatenation`,
            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
        **kwargs: Forwarded to :meth:`init_weights` (``std``,
            ``zeros_init``).

    Raises:
        ValueError: If ``mode`` is not one of the four supported options.
    """

    def __init__(self,
                 in_channels: int,
                 reduction: int = 2,
                 use_scale: bool = True,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Optional[Dict] = None,
                 mode: str = 'embedded_gaussian',
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        # Never drop below one channel, even for a large reduction ratio.
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode
        # `forward` reads this flag in `gaussian` mode. Subclasses overwrite
        # it; the default keeps the base class usable on its own.
        self.sub_sample = False

        if mode not in [
                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)  # type: ignore
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # `gaussian` mode compares the raw input with itself, so it needs
        # no theta/phi embeddings.
        if self.mode != 'gaussian':
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)

        if self.mode == 'concatenation':
            # No `conv_cfg` here: `concatenation` always feeds this layer a
            # 4-D tensor, whatever the dimensionality of the block.
            self.concat_project = ConvModule(
                self.inter_channels * 2,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=dict(type='ReLU'))

        self.init_weights(**kwargs)

    @staticmethod
    def _conv_module_of(module: nn.Module) -> nn.Module:
        """Return the object exposing ``.conv`` for an embedding branch.

        With sub-sampling enabled, ``g`` and ``phi`` are
        ``nn.Sequential(conv_module, max_pool)``; the conv module is then
        the first child. Modules that already expose ``.conv`` are returned
        unchanged.
        """
        return module if hasattr(module, 'conv') else module[0]

    def _attach_max_pool(self, max_pool_layer: nn.Module) -> None:
        """Append ``max_pool_layer`` to the ``g`` and ``phi`` branches.

        In `gaussian` mode there is no ``phi`` embedding, so the pooling
        layer itself becomes ``phi``.
        """
        self.g = nn.Sequential(self.g, max_pool_layer)
        if self.mode != 'gaussian':
            self.phi = nn.Sequential(self.phi, max_pool_layer)
        else:
            self.phi = max_pool_layer

    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
        """Initialize the embedding convs and the output projection.

        Args:
            std (float): Standard deviation for the normal initialization
                of ``g`` (and ``theta``/``phi`` when present). Default: 0.01.
            zeros_init (bool): If True, zero-initialize the last layer of
                ``conv_out`` (its norm layer when ``norm_cfg`` is set,
                otherwise its conv); if False, use normal initialization
                with ``std`` instead. Default: True.
        """
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(self._conv_module_of(m).conv, std=std)
        else:
            normal_init(self._conv_module_of(self.g).conv, std=std)

        # Only the last layer of `conv_out` gets the special treatment.
        if self.conv_out.norm_cfg is None:
            last_layer = self.conv_out.conv
        else:
            last_layer = self.conv_out.norm

        if zeros_init:
            constant_init(last_layer, 0)
        else:
            normal_init(last_layer, std=std)

    def gaussian(self, theta_x: torch.Tensor,
                 phi_x: torch.Tensor) -> torch.Tensor:
        """Softmax over the last dim of ``matmul(theta_x, phi_x)``."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def embedded_gaussian(self, theta_x: torch.Tensor,
                          phi_x: torch.Tensor) -> torch.Tensor:
        """Like :meth:`gaussian`, optionally scaled before the softmax.

        When ``use_scale`` is set, the product is divided by the square
        root of ``theta_x.shape[-1]``.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x: torch.Tensor,
                    phi_x: torch.Tensor) -> torch.Tensor:
        """``matmul(theta_x, phi_x)`` divided by the size of its last dim."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def concatenation(self, theta_x: torch.Tensor,
                      phi_x: torch.Tensor) -> torch.Tensor:
        """Pairwise weight from ``concat_project`` on concatenated features.

        ``theta_x`` and ``phi_x`` are 4-D here (see :meth:`forward`); they
        are tiled to a common grid, concatenated along the channel dim,
        projected to one channel and divided by the size of the last dim.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        # Divide out of place: this tensor is a view of the `concat_project`
        # output, which ends in an activation whose result autograd may have
        # saved for backward; an in-place `/=` would modify that saved
        # tensor.
        pairwise_weight = pairwise_weight / pairwise_weight.shape[-1]

        return pairwise_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the non-local operation and add it to ``x`` as a residual.

        Args:
            x (torch.Tensor): Input with ``in_channels`` channels in dim 1.

        Returns:
            torch.Tensor: ``x + conv_out(y)``, where ``y`` is the
            aggregated ``g`` embedding reshaped to the spatial size of
            ``x``.
        """
        # The shape comments below assume `reduction = 1`, i.e.
        # `inter_channels = C`.

        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
        if self.mode == 'gaussian':
            # No embeddings in this mode: the raw input is compared with
            # itself (with `phi` being the pooling layer if sub-sampling).
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            # Kept 4-D so `concatenation` can tile them against each other.
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # `self.mode` was validated in `__init__` and names one of the
        # pairwise methods defined on this class.
        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output


class NonLocal1d(_NonLocalNd):
    """1D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer
            (``nn.MaxPool1d(kernel_size=2)``) to the `g` and `phi` branches
            (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv1d').
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            self._attach_max_pool(nn.MaxPool1d(kernel_size=2))


@MODELS.register_module()
class NonLocal2d(_NonLocalNd):
    """2D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer
            (``nn.MaxPool2d(kernel_size=(2, 2))``) to the `g` and `phi`
            branches (Note that the `sub_sample` is applied on spatial
            only). Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv2d').
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            self._attach_max_pool(nn.MaxPool2d(kernel_size=(2, 2)))


class NonLocal3d(_NonLocalNd):
    """3D Non-local module.

    Args:
        in_channels (int): Same as `_NonLocalNd`.
        sub_sample (bool): Whether to append a max pooling layer
            (``nn.MaxPool3d(kernel_size=(1, 2, 2))``) to the `g` and `phi`
            branches (Note that the `sub_sample` is applied on spatial
            only). Default: False.
        conv_cfg (None | dict): Same as `_NonLocalNd`.
            Default: dict(type='Conv3d').
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if sub_sample:
            # Kernel size 1 on the first dim: only the last two dims pool.
            self._attach_max_pool(nn.MaxPool3d(kernel_size=(1, 2, 2)))