mart9992's picture
m
2cd560a
raw
history blame
12.5 kB
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
normal_init)
from mmcv.utils import digit_version
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.utils.ops import resize
from ..backbones.resnet import BasicBlock, Bottleneck
from ..builder import NECKS
try:
from mmcv.ops import DeformConv2d
has_mmcv_full = True
except (ImportError, ModuleNotFoundError):
has_mmcv_full = False
@NECKS.register_module()
class PoseWarperNeck(nn.Module):
"""PoseWarper neck.
`"Learning temporal pose estimation from sparsely-labeled videos"
<https://arxiv.org/abs/1906.04016>`_.
Args:
in_channels (int): Number of input channels from backbone
out_channels (int): Number of output channels
inner_channels (int): Number of intermediate channels of the res block
deform_groups (int): Number of groups in the deformable conv
dilations (list|tuple): different dilations of the offset conv layers
trans_conv_kernel (int): the kernel of the trans conv layer, which is
used to get heatmap from the output of backbone. Default: 1
res_blocks_cfg (dict|None): config of residual blocks. If None,
use the default values. If not None, it should contain the
following keys:
- block (str): the type of residual block, Default: 'BASIC'.
- num_blocks (int): the number of blocks, Default: 20.
offsets_kernel (int): the kernel of offset conv layer.
deform_conv_kernel (int): the kernel of defomrable conv layer.
in_index (int|Sequence[int]): Input feature index. Default: 0
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
Default: None.
- 'resize_concat': Multiple feature maps will be resize to \
the same size as first one and than concat together. \
Usually used in FCN head of HRNet.
- 'multiple_select': Multiple feature maps will be bundle into \
a list and passed into decode head.
- None: Only one select feature map is allowed.
freeze_trans_layer (bool): Whether to freeze the transition layer
(stop grad and set eval mode). Default: True.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
im2col_step (int): the argument `im2col_step` in deformable conv,
Default: 80.
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
minimum_mmcv_version = '1.3.17'
def __init__(self,
in_channels,
out_channels,
inner_channels,
deform_groups=17,
dilations=(3, 6, 12, 18, 24),
trans_conv_kernel=1,
res_blocks_cfg=None,
offsets_kernel=3,
deform_conv_kernel=3,
in_index=0,
input_transform=None,
freeze_trans_layer=True,
norm_eval=False,
im2col_step=80):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.inner_channels = inner_channels
self.deform_groups = deform_groups
self.dilations = dilations
self.trans_conv_kernel = trans_conv_kernel
self.res_blocks_cfg = res_blocks_cfg
self.offsets_kernel = offsets_kernel
self.deform_conv_kernel = deform_conv_kernel
self.in_index = in_index
self.input_transform = input_transform
self.freeze_trans_layer = freeze_trans_layer
self.norm_eval = norm_eval
self.im2col_step = im2col_step
identity_trans_layer = False
assert trans_conv_kernel in [0, 1, 3]
kernel_size = trans_conv_kernel
if kernel_size == 3:
padding = 1
elif kernel_size == 1:
padding = 0
else:
# 0 for Identity mapping.
identity_trans_layer = True
if identity_trans_layer:
self.trans_layer = nn.Identity()
else:
self.trans_layer = build_conv_layer(
cfg=dict(type='Conv2d'),
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding)
# build chain of residual blocks
if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
raise TypeError('res_blocks_cfg should be dict or None.')
if res_blocks_cfg is None:
block_type = 'BASIC'
num_blocks = 20
else:
block_type = res_blocks_cfg.get('block', 'BASIC')
num_blocks = res_blocks_cfg.get('num_blocks', 20)
block = self.blocks_dict[block_type]
res_layers = []
downsample = nn.Sequential(
build_conv_layer(
cfg=dict(type='Conv2d'),
in_channels=out_channels,
out_channels=inner_channels,
kernel_size=1,
stride=1,
bias=False),
build_norm_layer(dict(type='BN'), inner_channels)[1])
res_layers.append(
block(
in_channels=out_channels,
out_channels=inner_channels,
downsample=downsample))
for _ in range(1, num_blocks):
res_layers.append(block(inner_channels, inner_channels))
self.offset_feats = nn.Sequential(*res_layers)
# build offset layers
self.num_offset_layers = len(dilations)
assert self.num_offset_layers > 0, 'Number of offset layers ' \
'should be larger than 0.'
target_offset_channels = 2 * offsets_kernel**2 * deform_groups
offset_layers = [
build_conv_layer(
cfg=dict(type='Conv2d'),
in_channels=inner_channels,
out_channels=target_offset_channels,
kernel_size=offsets_kernel,
stride=1,
dilation=dilations[i],
padding=dilations[i],
bias=False,
) for i in range(self.num_offset_layers)
]
self.offset_layers = nn.ModuleList(offset_layers)
# build deformable conv layers
assert digit_version(mmcv.__version__) >= \
digit_version(self.minimum_mmcv_version), \
f'Current MMCV version: {mmcv.__version__}, ' \
f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
f'Please install the latest MMCV.'
if has_mmcv_full:
deform_conv_layers = [
DeformConv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=deform_conv_kernel,
stride=1,
padding=int(deform_conv_kernel / 2) * dilations[i],
dilation=dilations[i],
deform_groups=deform_groups,
im2col_step=self.im2col_step,
) for i in range(self.num_offset_layers)
]
else:
raise ImportError('Please install the full version of mmcv '
'to use `DeformConv2d`.')
self.deform_conv_layers = nn.ModuleList(deform_conv_layers)
self.freeze_layers()
def freeze_layers(self):
if self.freeze_trans_layer:
self.trans_layer.eval()
for param in self.trans_layer.parameters():
param.requires_grad = False
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, std=0.001)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
elif isinstance(m, DeformConv2d):
filler = torch.zeros([
m.weight.size(0),
m.weight.size(1),
m.weight.size(2),
m.weight.size(3)
],
dtype=torch.float32,
device=m.weight.device)
for k in range(m.weight.size(0)):
filler[k, k,
int(m.weight.size(2) / 2),
int(m.weight.size(3) / 2)] = 1.0
m.weight = torch.nn.Parameter(filler)
m.weight.requires_grad = True
# posewarper offset layer weight initialization
for m in self.offset_layers.modules():
constant_init(m, 0)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor] | Tensor): multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if not isinstance(inputs, list):
return inputs
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def forward(self, inputs, frame_weight):
assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
'should be list or tuple, even though the length is 1, ' \
'for unified processing.'
output_heatmap = 0
if len(inputs) > 1:
inputs = [self._transform_inputs(input) for input in inputs]
inputs = [self.trans_layer(input) for input in inputs]
# calculate difference features
diff_features = [
self.offset_feats(inputs[0] - input) for input in inputs
]
for i in range(len(inputs)):
if frame_weight[i] == 0:
continue
warped_heatmap = 0
for j in range(self.num_offset_layers):
offset = (self.offset_layers[j](diff_features[i]))
warped_heatmap_tmp = self.deform_conv_layers[j](inputs[i],
offset)
warped_heatmap += warped_heatmap_tmp / \
self.num_offset_layers
output_heatmap += warped_heatmap * frame_weight[i]
else:
inputs = inputs[0]
inputs = self._transform_inputs(inputs)
inputs = self.trans_layer(inputs)
num_frames = len(frame_weight)
batch_size = inputs.size(0) // num_frames
ref_x = inputs[:batch_size]
ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)
offset_features = self.offset_feats(ref_x_tiled - inputs)
warped_heatmap = 0
for j in range(self.num_offset_layers):
offset = self.offset_layers[j](offset_features)
warped_heatmap_tmp = self.deform_conv_layers[j](inputs, offset)
warped_heatmap += warped_heatmap_tmp / self.num_offset_layers
for i in range(num_frames):
if frame_weight[i] == 0:
continue
output_heatmap += warped_heatmap[i * batch_size:(i + 1) *
batch_size] * frame_weight[i]
return output_heatmap
def train(self, mode=True):
"""Convert the model into training mode."""
super().train(mode)
self.freeze_layers()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()