rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
11.6 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta
from typing import Dict, Optional
import torch
import torch.nn as nn
from mmengine.model import constant_init, normal_init
from mmengine.registry import MODELS
from .conv_module import ConvModule
class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971
    Code reference: https://github.com/AlexHex7/Non-local_pytorch

    The block computes ``x + conv_out(pairwise_weight @ g(x))`` where
    ``pairwise_weight`` is produced by the method named by ``mode``.

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. The embedding width is
            ``max(in_channels // reduction, 1)``. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `gaussian`, `concatenation`,
            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
        **kwargs: Forwarded to :meth:`init_weights` (``std``,
            ``zeros_init``).

    Raises:
        ValueError: If ``mode`` is not one of the four supported options.
    """

    # Default for the flag read by ``forward`` in ``gaussian`` mode.
    # Subclasses overwrite it per instance; without this fallback, using the
    # base class directly fails with AttributeError.
    sub_sample: bool = False

    def __init__(self,
                 in_channels: int,
                 reduction: int = 2,
                 use_scale: bool = True,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Optional[Dict] = None,
                 mode: str = 'embedded_gaussian',
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        # Never let the embedding collapse to zero channels.
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode

        if mode not in [
                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)  # type: ignore
        # conv_out is the only layer that receives ``norm_cfg``.
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # ``gaussian`` mode compares the raw input with itself, so it needs
        # no theta/phi embeddings.
        if self.mode != 'gaussian':
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)

        if self.mode == 'concatenation':
            # Projects the concatenated (theta, phi) pair to one scalar per
            # position pair. ``conv_cfg`` is intentionally not forwarded
            # here; ``concatenation`` feeds this layer 4-D tensors.
            self.concat_project = ConvModule(
                self.inter_channels * 2,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=dict(type='ReLU'))

        self.init_weights(**kwargs)

    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
        """Initialise the embedding convs and the output projection.

        Args:
            std (float): Standard deviation passed to ``normal_init``.
                Default: 0.01.
            zeros_init (bool): If True, the last layer of ``conv_out`` (its
                norm layer when ``norm_cfg`` is set, otherwise its conv) is
                filled with the constant 0; otherwise it is initialised with
                ``normal_init``. Default: True.
        """
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(m.conv, std=std)
        else:
            # theta/phi do not exist in ``gaussian`` mode.
            normal_init(self.g.conv, std=std)

        if zeros_init:
            if self.conv_out.norm_cfg is None:
                constant_init(self.conv_out.conv, 0)
            else:
                constant_init(self.conv_out.norm, 0)
        else:
            if self.conv_out.norm_cfg is None:
                normal_init(self.conv_out.conv, std=std)
            else:
                normal_init(self.conv_out.norm, std=std)

    def gaussian(self, theta_x: torch.Tensor,
                 phi_x: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(theta_x @ phi_x)`` over the last dimension."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def embedded_gaussian(self, theta_x: torch.Tensor,
                          phi_x: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(theta_x @ phi_x)``, optionally scaled first.

        When ``self.use_scale`` is set, the product is divided by the square
        root of ``theta_x.shape[-1]`` before the softmax.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x: torch.Tensor,
                    phi_x: torch.Tensor) -> torch.Tensor:
        """Return ``theta_x @ phi_x`` divided by its last dimension size."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def concatenation(self, theta_x: torch.Tensor,
                      phi_x: torch.Tensor) -> torch.Tensor:
        """Return pairwise weights from the projected (theta, phi) pairs.

        ``theta_x`` is tiled along dim 3 and ``phi_x`` along dim 2 so that
        every position pair is represented, the two are concatenated on the
        channel dim, projected by ``self.concat_project`` and divided by the
        size of the last dimension.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        # Divide out of place: ``pairwise_weight`` is a view of the ReLU
        # output of ``concat_project``, which autograd keeps for the backward
        # pass. An in-place ``/=`` would overwrite it and make ``backward()``
        # raise a "modified by an inplace operation" RuntimeError.
        pairwise_weight = pairwise_weight / pairwise_weight.shape[-1]

        return pairwise_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the non-local block and add the residual connection.

        Args:
            x (torch.Tensor): Input whose dim 0 is the batch and dim 1 has
                ``in_channels`` entries; remaining dims are flattened.

        Returns:
            torch.Tensor: ``x + conv_out(y)``, where ``y`` is reshaped to
            ``x``'s trailing dims with ``inter_channels`` channels.
        """
        # Assume `reduction = 1`, then `inter_channels = C`
        # or `inter_channels = C` when `mode="gaussian"`

        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                # In this mode ``phi`` only exists when a subclass installed
                # it for sub-sampling.
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            # Keep singleton dims so ``concatenation`` can tile both inputs.
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # ``mode`` was validated in ``__init__`` and doubles as the name of
        # the pairwise method.
        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output
class NonLocal1d(_NonLocalNd):
    """1D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, a shared ``nn.MaxPool1d(kernel_size=2)``
            is appended to ``g`` and to ``phi`` (in `gaussian` mode, where
            no ``phi`` conv exists, the pooling layer itself becomes
            ``phi``). Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
        **kwargs: Remaining keyword arguments, passed to the base class.
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing else to wire up unless down-sampling was requested.
        if not sub_sample:
            return

        # One pooling layer instance is reused for both branches.
        pool = nn.MaxPool1d(kernel_size=2)
        self.g = nn.Sequential(self.g, pool)
        if self.mode == 'gaussian':
            self.phi = pool
        else:
            self.phi = nn.Sequential(self.phi, pool)
@MODELS.register_module()
class NonLocal2d(_NonLocalNd):
    """2D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, a shared
            ``nn.MaxPool2d(kernel_size=(2, 2))`` is appended to ``g`` and to
            ``phi`` (in `gaussian` mode, where no ``phi`` conv exists, the
            pooling layer itself becomes ``phi``). Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
        **kwargs: Remaining keyword arguments, passed to the base class.
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing else to wire up unless down-sampling was requested.
        if not sub_sample:
            return

        # One pooling layer instance is reused for both branches.
        pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.g = nn.Sequential(self.g, pool)
        if self.mode == 'gaussian':
            self.phi = pool
        else:
            self.phi = nn.Sequential(self.phi, pool)
class NonLocal3d(_NonLocalNd):
    """3D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, a shared
            ``nn.MaxPool3d(kernel_size=(1, 2, 2))`` is appended to ``g`` and
            to ``phi`` (in `gaussian` mode, where no ``phi`` conv exists, the
            pooling layer itself becomes ``phi``). The kernel is 1 along the
            first pooled dim, so only the last two dims are reduced.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
        **kwargs: Remaining keyword arguments, passed to the base class.
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing else to wire up unless down-sampling was requested.
        if not sub_sample:
            return

        # One pooling layer instance is reused for both branches.
        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.g = nn.Sequential(self.g, pool)
        if self.mode == 'gaussian':
            self.phi = pool
        else:
            self.phi = nn.Sequential(self.phi, pool)