# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional, Sequence, Tuple, Type, Union

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmengine.model import constant_init, kaiming_init
from mmengine.runner import load_checkpoint
from torch import Tensor


def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)


class BasicBlock(nn.Module):
    """Basic residual block (two 3x3 convs), used by ResNet-18/34."""

    expansion = 1

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 dilation: int = 1,
                 downsample: Optional[nn.Module] = None,
                 style: str = 'pytorch',
                 with_cp: bool = False):
        super().__init__()
        assert style in ['pytorch', 'caffe']
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        # Gradient checkpointing is not supported for the basic block.
        assert not with_cp

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
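
# Usage sketch (kept as a comment so it is not executed on import), with
# hypothetical 64->128 stride-2 shapes: whenever a block changes resolution
# or width, the identity path must be projected by a matching `downsample`
# module, or `out += residual` would fail on a shape mismatch.
#
#     import torch
#     downsample = nn.Sequential(
#         nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
#         nn.BatchNorm2d(128),
#     )
#     block = BasicBlock(64, 128, stride=2, downsample=downsample)
#     out = block(torch.randn(1, 64, 56, 56))  # torch.Size([1, 128, 28, 28])
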
""" super().__init__() assert style in ['pytorch', 'caffe'] if style == 'pytorch': conv1_stride = 1 conv2_stride = stride else: conv1_stride = stride conv2_stride = 1 self.conv1 = nn.Conv2d( inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=conv2_stride, padding=dilation, dilation=dilation, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d( planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation self.with_cp = with_cp def forward(self, x: Tensor) -> Tensor: def _inner_forward(x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out def make_res_layer(block: nn.Module, inplanes: int, planes: int, blocks: int, stride: int = 1, dilation: int = 1, style: str = 'pytorch', with_cp: bool = False) -> nn.Module: downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d( inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append( block( inplanes, planes, stride, dilation, downsample, style=style, with_cp=with_cp)) inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp)) return nn.Sequential(*layers) class ResNet(nn.Module): """ResNet backbone. Args: depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. num_stages (int): Resnet stages, normally 4. strides (Sequence[int]): Strides of the first block of each stage. dilations (Sequence[int]): Dilation of each stage. out_indices (Sequence[int]): Output from which stages. style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two layer is the 3x3 conv layer, otherwise the stride-two layer is the first 1x1 conv layer. frozen_stages (int): Stages to be frozen (all param fixed). -1 means not freezing any parameters. bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze running stats (mean and var). bn_frozen (bool): Whether to freeze weight and bias of BN layers. with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. 
""" arch_settings = { 18: (BasicBlock, (2, 2, 2, 2)), 34: (BasicBlock, (3, 4, 6, 3)), 50: (Bottleneck, (3, 4, 6, 3)), 101: (Bottleneck, (3, 4, 23, 3)), 152: (Bottleneck, (3, 8, 36, 3)) } def __init__(self, depth: int, num_stages: int = 4, strides: Sequence[int] = (1, 2, 2, 2), dilations: Sequence[int] = (1, 1, 1, 1), out_indices: Sequence[int] = (0, 1, 2, 3), style: str = 'pytorch', frozen_stages: int = -1, bn_eval: bool = True, bn_frozen: bool = False, with_cp: bool = False): super().__init__() if depth not in self.arch_settings: raise KeyError(f'invalid depth {depth} for resnet') assert num_stages >= 1 and num_stages <= 4 block, stage_blocks = self.arch_settings[depth] stage_blocks = stage_blocks[:num_stages] # type: ignore assert len(strides) == len(dilations) == num_stages assert max(out_indices) < num_stages self.out_indices = out_indices self.style = style self.frozen_stages = frozen_stages self.bn_eval = bn_eval self.bn_frozen = bn_frozen self.with_cp = with_cp self.inplanes: int = 64 self.conv1 = nn.Conv2d( 3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.res_layers = [] for i, num_blocks in enumerate(stage_blocks): stride = strides[i] dilation = dilations[i] planes = 64 * 2**i res_layer = make_res_layer( block, self.inplanes, planes, num_blocks, stride=stride, dilation=dilation, style=self.style, with_cp=with_cp) self.inplanes = planes * block.expansion # type: ignore layer_name = f'layer{i + 1}' self.add_module(layer_name, res_layer) self.res_layers.append(layer_name) self.feat_dim = block.expansion * 64 * 2**( # type: ignore len(stage_blocks) - 1) def init_weights(self, pretrained: Optional[str] = None) -> None: if isinstance(pretrained, str): logger = logging.getLogger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Conv2d): kaiming_init(m) elif isinstance(m, nn.BatchNorm2d): constant_init(m, 1) else: raise TypeError('pretrained must be a str or None') def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) outs = [] for i, layer_name in enumerate(self.res_layers): res_layer = getattr(self, layer_name) x = res_layer(x) if i in self.out_indices: outs.append(x) if len(outs) == 1: return outs[0] else: return tuple(outs) def train(self, mode: bool = True) -> None: super().train(mode) if self.bn_eval: for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() if self.bn_frozen: for params in m.parameters(): params.requires_grad = False if mode and self.frozen_stages >= 0: for param in self.conv1.parameters(): param.requires_grad = False for param in self.bn1.parameters(): param.requires_grad = False self.bn1.eval() self.bn1.weight.requires_grad = False self.bn1.bias.requires_grad = False for i in range(1, self.frozen_stages + 1): mod = getattr(self, f'layer{i}') mod.eval() for param in mod.parameters(): param.requires_grad = False