# Modified from detectron2: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13
import torch
from torch import nn
from torch.nn import functional as F


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform an identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.

    See the usage sketch and equivalence check at the end of this file.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This silences the warnings about missing keys.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)