|
|
|
import torch |
|
from mmengine.registry import MODELS |
|
from torch import Tensor |
|
from torch import distributed as dist |
|
from torch import nn as nn |
|
from torch.autograd.function import Function |
|
|
|
|
|
class AllReduce(Function):
    """Autograd function that sums a tensor over all distributed workers.

    The forward pass gathers the tensor from every rank of the default
    process group and sums the gathered copies. The backward pass sums the
    incoming gradient over all ranks. A process group must already be
    initialized when this function is applied.
    """

    @staticmethod
    def forward(ctx, input: Tensor) -> Tensor:
        """Gather ``input`` from every rank and return the element-wise sum.

        Args:
            ctx: Autograd context (unused).
            input (Tensor): Local tensor; it is gathered into buffers created
                with ``torch.zeros_like(input)``, one per rank.

        Returns:
            Tensor: Sum of the gathered tensors, same shape as ``input``.
        """
        # One receive buffer per rank.
        gathered = [
            torch.zeros_like(input) for _ in range(dist.get_world_size())
        ]

        dist.all_gather(gathered, input, async_op=False)
        stacked = torch.stack(gathered, dim=0)
        return torch.sum(stacked, dim=0)

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        """Sum the incoming gradient over all ranks.

        Args:
            ctx: Autograd context (unused).
            grad_output (Tensor): Gradient w.r.t. the output of ``forward``.

        Returns:
            Tensor: A new tensor holding ``grad_output`` summed over ranks.
        """
        # The gradient handed to ``backward`` must not be modified in place:
        # autograd may share it with other consumers, and it can be a
        # non-contiguous (e.g. expanded) view. Reduce a private contiguous
        # copy instead.
        grad_input = grad_output.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad_input, async_op=False)
        return grad_input
|
|
|
|
|
@MODELS.register_module('naiveSyncBN1d')
class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
    """Synchronized Batch Normalization for 3D Tensors.

    Note:
        This implementation is modified from
        https://github.com/facebookresearch/detectron2/

        `torch.nn.SyncBatchNorm` has known bugs whose cause is not yet
        understood. It produces significantly worse AP (and sometimes goes
        NaN) when the batch size on each worker is quite different
        (e.g., when scale augmentation is used).
        In 3D detection, different workers have points of different shapes,
        which also causes instability.

        Use this implementation before `nn.SyncBatchNorm` is fixed.
        It is slower than `nn.SyncBatchNorm`.
    """

    def __init__(self, *args: list, **kwargs: dict) -> None:
        super().__init__(*args, **kwargs)

    def _update_running_stats(self, mean: Tensor, var: Tensor) -> None:
        """Blend the synchronized batch statistics into the running buffers.

        Does nothing when running statistics are not tracked. With
        ``momentum=None`` a cumulative moving average is used, mirroring
        ``nn.BatchNorm1d``.
        """
        if not self.track_running_stats or self.running_mean is None:
            return
        with torch.no_grad():
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
            if self.momentum is None:
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum
            self.running_mean += factor * (mean.detach() - self.running_mean)
            self.running_var += factor * (var.detach() - self.running_var)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` with statistics averaged over all workers.

        Falls back to the plain ``nn.BatchNorm1d`` forward when distributed
        training is not in use, the world size is 1, or the module is in
        evaluation mode.

        Args:
            input (Tensor): Has shape (N, C) or (N, C, L), where N is
                the batch size, C is the number of features or
                channels, and L is the sequence length

        Returns:
            Tensor: Has shape (N, C) or (N, C, L), same shape as input.
        """
        using_dist = dist.is_available() and dist.is_initialized()
        if not using_dist or dist.get_world_size() == 1 or not self.training:
            return super().forward(input)
        assert input.shape[0] > 0, 'SyncBN does not support empty inputs'

        # Treat (N, C) as (N, C, 1) so one code path handles both layouts.
        is_two_dim = input.dim() == 2
        if is_two_dim:
            input = input.unsqueeze(2)

        num_channels = input.shape[1]
        mean = torch.mean(input, dim=[0, 2])
        meansqr = torch.mean(input * input, dim=[0, 2])

        # Pack both statistics into one vector so a single collective call
        # averages them over the workers.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, num_channels)
        # E[x^2] - E[x]^2 can turn slightly negative through floating-point
        # cancellation; clamp so that rsqrt below cannot produce NaN.
        var = torch.clamp(meansqr - mean * mean, min=0.0)
        self._update_running_stats(mean, var)

        invstd = torch.rsqrt(var + self.eps)
        # weight / bias are None when the module was built with affine=False.
        scale = invstd if self.weight is None else self.weight * invstd
        if self.bias is None:
            bias = -mean * scale
        else:
            bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1)
        bias = bias.reshape(1, -1, 1)
        output = input * scale + bias
        if is_two_dim:
            output = output.squeeze(2)
        return output
|
|
|
|
|
@MODELS.register_module('naiveSyncBN2d')
class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
    """Synchronized Batch Normalization for 4D Tensors.

    Note:
        This implementation is modified from
        https://github.com/facebookresearch/detectron2/

        `torch.nn.SyncBatchNorm` has known bugs whose cause is not yet
        understood. It produces significantly worse AP (and sometimes goes
        NaN) when the batch size on each worker is quite different
        (e.g., when scale augmentation is used).
        This phenomenon also occurs when the multi-modality feature fusion
        modules of multi-modality detectors use SyncBN.

        Use this implementation before `nn.SyncBatchNorm` is fixed.
        It is slower than `nn.SyncBatchNorm`.
    """

    def __init__(self, *args: list, **kwargs: dict) -> None:
        super().__init__(*args, **kwargs)

    def _update_running_stats(self, mean: Tensor, var: Tensor) -> None:
        """Blend the synchronized batch statistics into the running buffers.

        Does nothing when running statistics are not tracked. With
        ``momentum=None`` a cumulative moving average is used, mirroring
        ``nn.BatchNorm2d``.
        """
        if not self.track_running_stats or self.running_mean is None:
            return
        with torch.no_grad():
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
            if self.momentum is None:
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum
            self.running_mean += factor * (mean.detach() - self.running_mean)
            self.running_var += factor * (var.detach() - self.running_var)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` with statistics averaged over all workers.

        Falls back to the plain ``nn.BatchNorm2d`` forward when distributed
        training is not in use, the world size is 1, or the module is in
        evaluation mode.

        Args:
            input (Tensor): Feature has shape (N, C, H, W).

        Returns:
            Tensor: Has shape (N, C, H, W), same shape as input.
        """
        assert input.dtype == torch.float32, \
            f'input should be in float32 type, got {input.dtype}'
        using_dist = dist.is_available() and dist.is_initialized()
        if not using_dist or dist.get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
        num_channels = input.shape[1]
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        # Pack both statistics into one vector so a single collective call
        # averages them over the workers.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, num_channels)
        # E[x^2] - E[x]^2 can turn slightly negative through floating-point
        # cancellation; clamp so that rsqrt below cannot produce NaN.
        var = torch.clamp(meansqr - mean * mean, min=0.0)
        self._update_running_stats(mean, var)

        invstd = torch.rsqrt(var + self.eps)
        # weight / bias are None when the module was built with affine=False.
        scale = invstd if self.weight is None else self.weight * invstd
        if self.bias is None:
            bias = -mean * scale
        else:
            bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias
|
|