# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch from torch import nn class FrozenBatchNorm2d(nn.Module): """ BatchNorm2d where the batch statistics and the affine parameters are fixed """ def __init__(self, n): super(FrozenBatchNorm2d, self).__init__() self.register_buffer("weight", torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) def forward(self, x): # Cast all fixed parameters to half() if necessary if x.dtype == torch.float16: self.weight = self.weight.half() self.bias = self.bias.half() self.running_mean = self.running_mean.half() self.running_var = self.running_var.half() scale = self.weight * self.running_var.rsqrt() bias = self.bias - self.running_mean * scale scale = scale.reshape(1, -1, 1, 1) bias = bias.reshape(1, -1, 1, 1) return x * scale + bias