# mart9992's picture
# m
# 2cd560a
# raw
# history blame
# 24 kB
# --------------------------------------------------------
# High Resolution Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Rao Fu, RainbowSecret
# --------------------------------------------------------
import pdb
import torch
import torch.nn as nn
from mmcv.cnn import (
build_conv_layer,
build_norm_layer,
constant_init,
kaiming_init,
normal_init,
)
# from mmcv.runner import load_checkpoint
from .hrt_checkpoint import load_checkpoint
from mmcv.runner.checkpoint import load_state_dict
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmpose.models.utils.ops import resize
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .modules.bottleneck_block import Bottleneck
from .modules.transformer_block import GeneralTransformerBlock
class HighResolutionTransformerModule(nn.Module):
    """Parallel multi-branch block stacks followed by cross-branch fusion.

    Every branch is an ``nn.Sequential`` of ``blocks`` instances. When more
    than one branch exists, each output branch is computed as the sum of all
    branch outputs after they have been passed through the matching fuse
    layer, followed by a ReLU.

    Args:
        num_branches (int): Number of parallel branches.
        blocks (type): Block class instantiated for every layer of every
            branch. It must expose an ``expansion`` attribute and accept the
            keyword arguments used in :meth:`_make_one_branch`.
        num_blocks (Sequence[int]): Number of blocks for each branch.
        in_channels (list[int]): Input channels for each branch. The list is
            stored by reference and updated in place while the branches are
            built (see :meth:`_make_one_branch`).
        num_channels (Sequence[int]): Base channel count for each branch; the
            branch output width is ``num_channels[i] * blocks.expansion``.
        multiscale_output (bool): When False, fuse layers are built for the
            first output branch only.
        with_cp (bool): Stored on the module; not read anywhere in this class.
        conv_cfg (dict | None): Forwarded to ``build_conv_layer`` and blocks.
        norm_cfg (dict): Forwarded to ``build_norm_layer`` and blocks.
        num_heads (Sequence[int]): Per-branch ``num_heads`` passed to blocks.
        num_window_sizes (Sequence[int]): Per-branch ``window_size`` passed
            to blocks.
        num_mlp_ratios (Sequence[float]): Per-branch ``mlp_ratio`` passed to
            blocks.
        drop_paths (float | Sequence[float]): ``drop_path`` value per block
            position. A scalar is applied to every block.
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        in_channels,
        num_channels,
        multiscale_output,
        with_cp=False,
        conv_cfg=None,
        norm_cfg=dict(type="BN", requires_grad=True),
        num_heads=None,
        num_window_sizes=None,
        num_mlp_ratios=None,
        drop_paths=0.0,
    ):
        super(HighResolutionTransformerModule, self).__init__()
        self._check_branches(num_branches, num_blocks, in_channels, num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches
        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        # Branches must be built before the fuse layers: building a branch
        # rewrites self.in_channels, which the fuse layers then read.
        self.branches = self._make_branches(
            num_branches,
            blocks,
            num_blocks,
            num_channels,
            num_heads,
            num_window_sizes,
            num_mlp_ratios,
            drop_paths,
        )
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

        # MHSA parameters
        self.num_heads = num_heads
        self.num_window_sizes = num_window_sizes
        self.num_mlp_ratios = num_mlp_ratios

    def _check_branches(self, num_branches, num_blocks, in_channels, num_channels):
        """Validate that every per-branch sequence has ``num_branches`` items.

        Raises:
            ValueError: If the length of ``num_blocks``, ``num_channels`` or
                ``in_channels`` differs from ``num_branches``. The message is
                also logged at error level.
        """
        checks = (
            ("NUM_BLOCKS", num_blocks),
            ("NUM_CHANNELS", num_channels),
            ("IN_CHANNELS", in_channels),
        )
        for label, values in checks:
            if num_branches != len(values):
                error_msg = "NUM_BRANCHES({}) <> {}({})".format(
                    num_branches, label, len(values)
                )
                # The logger is only needed on the failure path.
                get_root_logger().error(error_msg)
                raise ValueError(error_msg)

    def _make_one_branch(
        self,
        branch_index,
        block,
        num_blocks,
        num_channels,
        num_heads,
        num_window_sizes,
        num_mlp_ratios,
        drop_paths,
        stride=1,
    ):
        """Build the block stack of a single branch.

        At least one block is always created; ``num_blocks[branch_index]``
        controls how many follow it. After the first block has been created,
        ``self.in_channels[branch_index]`` is overwritten with
        ``num_channels[branch_index] * block.expansion``.

        Args:
            branch_index (int): Index of the branch being built.
            block (type): Block class to instantiate.
            num_blocks, num_channels, num_heads, num_window_sizes,
            num_mlp_ratios: Per-branch sequences indexed by ``branch_index``.
            drop_paths (float | Sequence[float]): ``drop_path`` per block
                position; a scalar is broadcast to every block.
            stride (int): Kept for interface compatibility. It is not
                forwarded to the blocks.

        Returns:
            nn.Sequential: The blocks of this branch, in order.
        """
        depth = num_blocks[branch_index]
        planes = num_channels[branch_index]

        # The signature default is a scalar, but the values are consumed by
        # position, so a scalar has to be expanded before indexing.
        if isinstance(drop_paths, (int, float)):
            drop_paths = [drop_paths] * max(depth, 1)

        # NOTE: a 1x1 conv + norm "downsample" module used to be constructed
        # here and was never handed to any block; it has been dropped.
        shared_kwargs = dict(
            num_heads=num_heads[branch_index],
            window_size=num_window_sizes[branch_index],
            mlp_ratio=num_mlp_ratios[branch_index],
            norm_cfg=self.norm_cfg,
            conv_cfg=self.conv_cfg,
        )

        layers = [
            block(
                self.in_channels[branch_index],
                planes,
                drop_path=drop_paths[0],
                **shared_kwargs,
            )
        ]
        # Later blocks (and the fuse layers) see the expanded channel width.
        self.in_channels[branch_index] = planes * block.expansion
        for i in range(1, depth):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    planes,
                    drop_path=drop_paths[i],
                    **shared_kwargs,
                )
            )
        return nn.Sequential(*layers)

    def _make_branches(
        self,
        num_branches,
        block,
        num_blocks,
        num_channels,
        num_heads,
        num_window_sizes,
        num_mlp_ratios,
        drop_paths,
    ):
        """Build every branch and return them as an ``nn.ModuleList``."""
        return nn.ModuleList(
            [
                self._make_one_branch(
                    index,
                    block,
                    num_blocks,
                    num_channels,
                    num_heads,
                    num_window_sizes,
                    num_mlp_ratios,
                    drop_paths,
                )
                for index in range(num_branches)
            ]
        )

    def _make_upsample_path(self, src_channels, dst_channels, scale_factor):
        """Return 1x1 conv -> norm -> bilinear upsample by ``scale_factor``."""
        return nn.Sequential(
            build_conv_layer(
                self.conv_cfg,
                src_channels,
                dst_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            build_norm_layer(self.norm_cfg, dst_channels)[1],
            nn.Upsample(
                scale_factor=scale_factor,
                mode="bilinear",
                align_corners=False,
            ),
        )

    def _make_downsample_step(self, src_channels, dst_channels, is_last):
        """Return one stride-2 step: depthwise 3x3 -> norm -> 1x1 -> norm.

        A trailing ReLU is appended for every step except the last one.
        """
        modules = [
            build_conv_layer(
                self.conv_cfg,
                src_channels,
                src_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=src_channels,
                bias=False,
            ),
            build_norm_layer(self.norm_cfg, src_channels)[1],
            build_conv_layer(
                self.conv_cfg,
                src_channels,
                dst_channels,
                kernel_size=1,
                stride=1,
                bias=False,
            ),
            build_norm_layer(self.norm_cfg, dst_channels)[1],
        ]
        if not is_last:
            modules.append(nn.ReLU(inplace=True))
        return nn.Sequential(*modules)

    def _make_fuse_layers(self):
        """Build fuse layer.

        Returns:
            nn.ModuleList | None: ``None`` for a single branch. Otherwise
            ``fuse_layers[i][j]`` maps the output of branch ``j`` onto output
            branch ``i``: an upsample path when ``j > i``, ``None`` when
            ``j == i``, and ``i - j`` chained stride-2 steps when ``j < i``.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        num_out_branches = num_branches if self.multiscale_output else 1

        fuse_layers = []
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        self._make_upsample_path(
                            in_channels[j], in_channels[i], 2 ** (j - i)
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    num_steps = i - j
                    conv_downsamples = []
                    for k in range(num_steps):
                        is_last = k == num_steps - 1
                        # Only the final step changes the channel count.
                        out_channels = in_channels[i] if is_last else in_channels[j]
                        conv_downsamples.append(
                            self._make_downsample_step(
                                in_channels[j], out_channels, is_last
                            )
                        )
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (list): One input per branch. For more than one branch the list
                entries are replaced in place by the branch outputs.

        Returns:
            list: One fused output per entry of ``self.fuse_layers`` (or a
            single-element list holding the branch output when there is only
            one branch).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            # Branch 0 contributes directly to output 0, otherwise through
            # its downsample path.
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    # Resize after the fuse layer so the spatial size matches
                    # x[i] exactly.
                    y = y + resize(
                        self.fuse_layers[i][j](x[j]),
                        size=x[i].shape[2:],
                        mode="bilinear",
                        align_corners=False,
                    )
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
@BACKBONES.register_module()
class HRT(nn.Module):
    """HRT backbone.

    High Resolution Transformer Backbone

    Args:
        extra (dict): Architecture configuration. ``__init__`` reads the keys
            ``"stage1"`` .. ``"stage4"`` and ``"drop_path_rate"``. Stage 1
            uses ``"num_channels"``, ``"block"`` and ``"num_blocks"``; stages
            2-4 additionally use ``"num_modules"``, ``"num_branches"``,
            ``"num_heads"``, ``"num_window_sizes"`` and ``"num_mlp_ratios"``.
            Stage 4 may set ``"multiscale_output"`` (default True).
        in_channels (int): Number of channels of the input to the stem.
        conv_cfg (dict | None): Passed to ``build_conv_layer``.
        norm_cfg (dict): Passed to ``build_norm_layer``.
        norm_eval (bool): When True, :meth:`train` keeps ``_BatchNorm``
            modules in eval mode.
        with_cp (bool): Forwarded to the blocks / modules that accept it.
        zero_init_residual (bool): When True, :meth:`init_weights` sets the
            ``norm3`` layer of every ``Bottleneck`` to zero.
    """

    blocks_dict = {
        "BOTTLENECK": Bottleneck,
        "TRANSFORMER_BLOCK": GeneralTransformerBlock,
    }

    def __init__(
        self,
        extra,
        in_channels=3,
        conv_cfg=None,
        norm_cfg=dict(type="BN", requires_grad=True),
        norm_eval=False,
        with_cp=False,
        zero_init_residual=False,
    ):
        super(HRT, self).__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net: two stride-2 3x3 convs, each followed by a norm layer
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg, 64, 64, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # generate drop path rate list: one value per block position across
        # stages 2-4, counted along the first branch of each stage
        depth_s2 = (
            self.extra["stage2"]["num_blocks"][0] * self.extra["stage2"]["num_modules"]
        )
        depth_s3 = (
            self.extra["stage3"]["num_blocks"][0] * self.extra["stage3"]["num_modules"]
        )
        depth_s4 = (
            self.extra["stage4"]["num_blocks"][0] * self.extra["stage4"]["num_modules"]
        )
        depths = [depth_s2, depth_s3, depth_s4]
        drop_path_rate = self.extra["drop_path_rate"]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        logger = get_root_logger()
        logger.info(dpr)

        # stage 1
        self.stage1_cfg = self.extra["stage1"]
        num_channels = self.stage1_cfg["num_channels"][0]
        block_type = self.stage1_cfg["block"]
        num_blocks = self.stage1_cfg["num_blocks"][0]
        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra["stage2"]
        num_channels = self.stage2_cfg["num_channels"]
        block_type = self.stage2_cfg["block"]
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channels], num_channels
        )
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels, drop_paths=dpr[0:depth_s2]
        )

        # stage 3
        self.stage3_cfg = self.extra["stage3"]
        num_channels = self.stage3_cfg["num_channels"]
        block_type = self.stage3_cfg["block"]
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg,
            num_channels,
            drop_paths=dpr[depth_s2 : depth_s2 + depth_s3],
        )

        # stage 4
        self.stage4_cfg = self.extra["stage4"]
        num_channels = self.stage4_cfg["num_channels"]
        block_type = self.stage4_cfg["block"]
        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg,
            num_channels,
            multiscale_output=self.stage4_cfg.get("multiscale_output", True),
            drop_paths=dpr[depth_s2 + depth_s3 :],
        )

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """Make transition layer.

        For a branch that already exists in the previous stage, a 3x3 conv +
        norm + ReLU is created only when the channel count changes; otherwise
        the entry is ``None``. Every additional branch is derived from the
        last previous branch through one or more stride-2 3x3 conv + norm +
        ReLU steps, the last of which switches to the new channel count.

        Args:
            num_channels_pre_layer (Sequence[int]): Channels per branch of the
                previous stage.
            num_channels_cur_layer (Sequence[int]): Channels per branch of the
                current stage.

        Returns:
            nn.ModuleList: One entry (module or ``None``) per current branch.
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[
                                1
                            ],
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    # Only the final step changes the channel count.
                    out_channels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else in_channels
                    )
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv_downsamples))
        return nn.ModuleList(transition_layers)

    def _make_layer(
        self,
        block,
        inplanes,
        planes,
        blocks,
        stride=1,
        num_heads=1,
        window_size=7,
        mlp_ratio=4.0,
    ):
        """Make each layer.

        Args:
            block (type): Block class; must expose ``expansion``.
            inplanes (int): Input channels of the first block.
            planes (int): Base channels; the layer outputs
                ``planes * block.expansion`` channels.
            blocks (int): Number of blocks to stack (at least one is built).
            stride (int): Stride of the first block (non-transformer blocks).
            num_heads, window_size, mlp_ratio: Used only when ``block`` is a
                ``GeneralTransformerBlock`` (sub)class.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        out_planes = planes * block.expansion

        # ``block`` is a class, so the check has to be issubclass; the former
        # isinstance() test could never be true.
        if isinstance(block, type) and issubclass(block, GeneralTransformerBlock):
            transformer_kwargs = dict(
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
            )
            layers = [block(inplanes, planes, **transformer_kwargs)]
            layers.extend(
                block(out_planes, planes, **transformer_kwargs)
                for _ in range(1, blocks)
            )
            return nn.Sequential(*layers)

        # Project the identity path when the first block changes resolution
        # or channel count.
        downsample = None
        if stride != 1 or inplanes != out_planes:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                build_norm_layer(self.norm_cfg, out_planes)[1],
            )

        layers = [
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
            )
        ]
        for _ in range(1, blocks):
            layers.append(
                block(
                    out_planes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                )
            )
        return nn.Sequential(*layers)

    def _make_stage(
        self, layer_config, in_channels, multiscale_output=True, drop_paths=0.0
    ):
        """Make each stage.

        Args:
            layer_config (dict): Stage configuration (see the class docstring
                for the keys that are read).
            in_channels (list[int]): Channels per branch entering the stage.
                The same list object is handed to every module and returned.
            multiscale_output (bool): Only affects the last module of the
                stage; all earlier modules always emit every branch.
            drop_paths (float | Sequence[float]): Drop-path values for the
                stage; module ``i`` receives the slice
                ``[num_blocks[0] * i : num_blocks[0] * (i + 1)]``. A scalar is
                applied to every block.

        Returns:
            tuple: ``(nn.Sequential of modules, in_channels)``.
        """
        num_modules = layer_config["num_modules"]
        num_branches = layer_config["num_branches"]
        num_blocks = layer_config["num_blocks"]
        num_channels = layer_config["num_channels"]
        block = self.blocks_dict[layer_config["block"]]
        num_heads = layer_config["num_heads"]
        num_window_sizes = layer_config["num_window_sizes"]
        num_mlp_ratios = layer_config["num_mlp_ratios"]

        blocks_per_module = num_blocks[0]
        # The signature default is a scalar, which cannot be sliced below.
        if isinstance(drop_paths, (int, float)):
            drop_paths = [drop_paths] * (blocks_per_module * num_modules)

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            is_last_module = i == num_modules - 1
            reset_multiscale_output = multiscale_output or not is_last_module

            hr_modules.append(
                HighResolutionTransformerModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    num_heads=num_heads,
                    num_window_sizes=num_window_sizes,
                    num_mlp_ratios=num_mlp_ratios,
                    drop_paths=drop_paths[
                        blocks_per_module * i : blocks_per_module * (i + 1)
                    ],
                )
            )
        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            ckpt = load_checkpoint(self, pretrained, strict=False)
            # NOTE(review): assumes load_checkpoint returns a mapping; some
            # checkpoints nest the weights under a "model" key - confirm.
            if "model" in ckpt:
                msg = self.load_state_dict(ckpt["model"], strict=False)
                logger.info(msg)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # mmseg: kaiming_init(m)
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                # Only Bottleneck is handled: blocks_dict offers no other
                # residual block, and the previously referenced BasicBlock
                # name is not defined in this module (it raised NameError).
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError("pretrained must be a str or None")

    def forward(self, x):
        """Forward function.

        Args:
            x: Input passed to the stem convolution.

        Returns:
            list: The outputs of stage 4, one entry per emitted branch.
        """
        # stem
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        # stage 2: every branch is derived from the single stage-1 output
        x_list = []
        for i in range(self.stage2_cfg["num_branches"]):
            transition = self.transition1[i]
            x_list.append(x if transition is None else transition(x))
        y_list = self.stage2(x_list)

        # stage 3: a branch with a transition is fed from the last branch of
        # the previous stage, the others are passed through unchanged
        x_list = []
        for i in range(self.stage3_cfg["num_branches"]):
            transition = self.transition2[i]
            if transition is not None:
                x_list.append(transition(y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        # stage 4: same rule as stage 3
        x_list = []
        for i in range(self.stage4_cfg["num_branches"]):
            transition = self.transition3[i]
            if transition is not None:
                x_list.append(transition(y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode.

        When ``norm_eval`` is set, ``_BatchNorm`` modules are switched back to
        eval mode after the regular ``train(mode)`` call.
        """
        super(HRT, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()