Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from abc import ABCMeta | |
from typing import Dict, Optional | |
import torch | |
import torch.nn as nn | |
from mmengine.model import constant_init, normal_init | |
from mmengine.registry import MODELS | |
from .conv_module import ConvModule | |
class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971
    Code reference: https://github.com/AlexHex7/Non-local_pytorch

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `gaussian`, `concatenation`,
            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
        **kwargs: Forwarded unchanged to :meth:`init_weights`
            (i.e. `std` and `zeros_init`).

    Raises:
        ValueError: If `mode` is not one of the four supported options.
    """

    # `forward` consults `self.sub_sample` in `gaussian` mode, but only the
    # concrete NonLocal{1,2,3}d subclasses assign it (after calling this
    # constructor). This class-level default keeps `forward` usable for the
    # base class and for subclasses that never set the flag; an instance
    # attribute assigned by a subclass still takes precedence.
    sub_sample: bool = False

    def __init__(self,
                 in_channels: int,
                 reduction: int = 2,
                 use_scale: bool = True,
                 conv_cfg: Optional[Dict] = None,
                 norm_cfg: Optional[Dict] = None,
                 mode: str = 'embedded_gaussian',
                 **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        # Never let the embedding collapse to zero channels.
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode

        if mode not in [
                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)  # type: ignore
        # Projects the aggregated features back to `in_channels`; this is
        # the only layer that receives `norm_cfg`.
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # `gaussian` mode compares the raw input with itself, so it has no
        # learned theta / phi embeddings.
        if self.mode != 'gaussian':
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)

        if self.mode == 'concatenation':
            # Maps each concatenated (theta, phi) pair to one scalar weight.
            self.concat_project = ConvModule(
                self.inter_channels * 2,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=dict(type='ReLU'))

        self.init_weights(**kwargs)

    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
        """Initialize the embedding and output layers.

        Args:
            std (float): Standard deviation passed to `normal_init`.
                Default: 0.01.
            zeros_init (bool): If True, the last layer of `conv_out` (its
                norm layer when `norm_cfg` is set, otherwise its conv) is
                initialized with `constant_init(..., 0)`; if False it is
                initialized with `normal_init` instead. Default: True.
        """
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(m.conv, std=std)
        else:
            # theta / phi do not exist in `gaussian` mode.
            normal_init(self.g.conv, std=std)

        # Only touch `.norm` when a norm layer was actually configured.
        if self.conv_out.norm_cfg is None:
            last_layer = self.conv_out.conv
        else:
            last_layer = self.conv_out.norm

        if zeros_init:
            constant_init(last_layer, 0)
        else:
            normal_init(last_layer, std=std)

    def gaussian(self, theta_x: torch.Tensor,
                 phi_x: torch.Tensor) -> torch.Tensor:
        """Pairwise weight as a softmax over the product `theta_x @ phi_x`."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def embedded_gaussian(self, theta_x: torch.Tensor,
                          phi_x: torch.Tensor) -> torch.Tensor:
        """Softmax over `theta_x @ phi_x`, optionally scaled beforehand.

        When `self.use_scale` is set, the product is divided by the square
        root of the last dimension of `theta_x` before the softmax.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x: torch.Tensor,
                    phi_x: torch.Tensor) -> torch.Tensor:
        """`theta_x @ phi_x` divided by the size of its last dimension."""
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def concatenation(self, theta_x: torch.Tensor,
                      phi_x: torch.Tensor) -> torch.Tensor:
        """Pairwise weight from `concat_project` over tiled theta/phi pairs.

        `theta_x` is expected as 4D with a trailing singleton dimension and
        `phi_x` as 4D with a singleton third dimension (this is how
        `forward` reshapes them); both are tiled to a common grid,
        concatenated along channels and projected to a single channel.
        """
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        # Drop the single output channel produced by `concat_project`.
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        pairwise_weight /= pairwise_weight.shape[-1]

        return pairwise_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the non-local operation and add the result to the input.

        Args:
            x (torch.Tensor): Input feature map; the first dimension is the
                batch and the second holds `in_channels` channels.

        Returns:
            torch.Tensor: `x + conv_out(y)`, where `y` is the aggregation
            of `g(x)` weighted by the pairwise function selected by
            `self.mode`.
        """
        # Assume `reduction = 1`, then `inter_channels = C`
        # or `inter_channels = C` when `mode="gaussian"`

        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            # `sub_sample` falls back to the class-level default (False)
            # when no subclass has set it; `self.phi` only exists in
            # gaussian mode if a subclass installed it for sub-sampling.
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # The mode name doubles as the name of the pairwise method.
        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output
class NonLocal1d(_NonLocalNd):
    """Non-local block for 1D feature maps.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool1d` layer (kernel size 2)
            is appended to the `g` branch and to the `phi` branch; in
            `gaussian` mode the pooling layer alone is used as `phi`.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
        **kwargs: Remaining keyword arguments, handed to the parent
            constructor.
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv1d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing more to build without sub-sampling.
        if not sub_sample:
            return

        # One pooling layer is shared by both branches.
        pool = nn.MaxPool1d(kernel_size=2)
        self.g = nn.Sequential(self.g, pool)
        # `gaussian` mode has no learned phi, so pooling stands in for it.
        self.phi = (
            pool if self.mode == 'gaussian' else nn.Sequential(self.phi, pool))
class NonLocal2d(_NonLocalNd):
    """Non-local block for 2D feature maps.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool2d` layer (kernel size
            2x2) is appended to the `g` branch and to the `phi` branch; in
            `gaussian` mode the pooling layer alone is used as `phi`.
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
        **kwargs: Remaining keyword arguments, handed to the parent
            constructor.
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv2d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing more to build without sub-sampling.
        if not sub_sample:
            return

        # One pooling layer is shared by both branches.
        pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.g = nn.Sequential(self.g, pool)
        # `gaussian` mode has no learned phi, so pooling stands in for it.
        self.phi = (
            pool if self.mode == 'gaussian' else nn.Sequential(self.phi, pool))
class NonLocal3d(_NonLocalNd):
    """Non-local block for 3D feature maps.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): If True, an `nn.MaxPool3d` layer with kernel
            size (1, 2, 2) is appended to the `g` branch and to the `phi`
            branch; in `gaussian` mode the pooling layer alone is used as
            `phi`. Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
        **kwargs: Remaining keyword arguments, handed to the parent
            constructor.
    """

    def __init__(self,
                 in_channels: int,
                 sub_sample: bool = False,
                 conv_cfg: Dict = dict(type='Conv3d'),
                 **kwargs):
        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        # Nothing more to build without sub-sampling.
        if not sub_sample:
            return

        # Kernel of 1 along the first pooled axis leaves that axis unpooled.
        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.g = nn.Sequential(self.g, pool)
        # `gaussian` mode has no learned phi, so pooling stands in for it.
        self.phi = (
            pool if self.mode == 'gaussian' else nn.Sequential(self.phi, pool))