import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.core import WeightNormClipHook
from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class BasicTemporalBlock(nn.Module):
    """Basic block for VideoPose3D.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use an optimized TCN designed specifically
            for single-frame batching, i.e. where batches have input
            length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
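
    Example:
        >>> # An illustrative shape check (not from the upstream docs). With
        >>> # the defaults (kernel_size=3, dilation=3) and no padding, the
        >>> # temporal length shrinks by (kernel_size - 1) * dilation = 6.
        >>> import torch
        >>> self = BasicTemporalBlock(in_channels=1024, out_channels=1024)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1024, 241)
        >>> out = self.forward(inputs)
        >>> print(tuple(out.shape))
        (1, 1024, 235)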
""" |
|
|
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
mid_channels=1024, |
|
kernel_size=3, |
|
dilation=3, |
|
dropout=0.25, |
|
causal=False, |
|
residual=True, |
|
use_stride_conv=False, |
|
conv_cfg=dict(type='Conv1d'), |
|
norm_cfg=dict(type='BN1d')): |
|
|
|
conv_cfg = copy.deepcopy(conv_cfg) |
|
norm_cfg = copy.deepcopy(norm_cfg) |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.mid_channels = mid_channels |
|
self.kernel_size = kernel_size |
|
self.dilation = dilation |
|
self.dropout = dropout |
|
self.causal = causal |
|
self.residual = residual |
|
self.use_stride_conv = use_stride_conv |
|
|
|
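        # `pad` frames are cropped from each end of the residual so it
        # matches the (shorter) conv output; `causal_shift` moves that window
        # so only past frames are used. In stride-conv mode the dilated
        # convolution is replaced by one with stride `kernel_size`, so
        # dilation is reset to 1.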
        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0

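        # conv1 aggregates temporal context with the (possibly dilated or
        # strided) kernel; conv2 is a pointwise (kernel_size=1) convolution
        # that mixes channels. Each ConvModule applies conv, norm and its
        # default activation.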
        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
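        # Sanity-check that the sequence is long enough for the crop/stride
        # pattern below; the chained comparison in the second branch verifies
        # 0 <= start < end <= T for the residual slice.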
        if self.use_stride_conv:
            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
        else:
            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
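            # In stride-conv mode the residual keeps the center frame of each
            # kernel-size window; otherwise it crops `pad` frames from both
            # ends (shifted by `causal_shift`) to match the conv output.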
            if self.use_stride_conv:
                res = x[:, :, self.causal_shift +
                        self.kernel_size // 2::self.kernel_size]
            else:
                res = x[:, :,
                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
                                                        self.causal_shift)]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out


@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    Args:
        in_channels (int): Number of input channels, which equals
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            each basic block. Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): If not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

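        # kernel_sizes[0] is consumed by expand_conv; each of the remaining
        # `num_blocks` entries is consumed by one temporal block.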
        assert num_blocks == len(kernel_sizes) - 1
        for ks in kernel_sizes:
            assert ks % 2 == 1, 'Only odd filter widths are supported.'

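        # expand_conv lifts the flattened keypoint features (in_channels =
        # num_keypoints * num_features) to stem_channels. Without padding it
        # shortens the sequence by kernel_sizes[0] - 1; in stride-conv mode
        # it downsamples time by a factor of kernel_sizes[0] instead.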
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=kernel_sizes[0],
            stride=kernel_sizes[0] if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

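        # The dilation grows by a factor of kernel_sizes[i] per block, so the
        # overall receptive field equals the product of all kernel sizes,
        # e.g. (3, 3, 3) -> 27 input frames per output frame.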
        dilation = kernel_sizes[0]
        self.tcn_blocks = nn.ModuleList()
        for i in range(1, num_blocks + 1):
            self.tcn_blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            dilation *= kernel_sizes[i]

        if self.max_norm is not None:
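            # Apply weight norm clipping so that every conv layer's weight
            # keeps a norm of at most `max_norm`, as in the original
            # VideoPose3D implementation.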
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        x = self.expand_conv(x)

        if self.dropout is not None:
            x = self.dropout(x)

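        # Collect and return the output of every temporal block; the class
        # docstring example shows one tensor per block.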
        outs = []
        for i in range(self.num_blocks):
            x = self.tcn_blocks[i](x)
            outs.append(x)

        return tuple(outs)

    def init_weights(self, pretrained=None):
        """Initialize the weights."""
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.modules.conv._ConvNd):
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)