import copy |
import torch.nn as nn |
from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init |
from mmcv.utils.parrots_wrapper import _BatchNorm |
from mmpose.core import WeightNormClipHook |
from ..builder import BACKBONES |
from .base_backbone import BaseBackbone |
class BasicTemporalBlock(nn.Module):
    """Basic temporal block of VideoPose3D.

    The block chains two convolution modules: ``conv1`` (a temporal
    convolution with the configured kernel size, stride and dilation) and
    ``conv2`` (kernel size 1). When the residual connection is enabled, a
    slice of the input taken along axis 2 is added to the output, after an
    optional 1x1 projection if the channel counts differ.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Replaced by 1 when
            ``use_stride_conv`` is True. Default: 3.
        dropout (float): Dropout rate. A value that is not greater than 0
            disables dropout. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use the optimized TCN designed specifically
            for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d')):
        super().__init__()
        # Work on private copies so the shared default dicts stay untouched.
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        # Number of frames cropped from each side of the residual input in
        # the dilated (non-strided) mode.
        self.pad = (kernel_size - 1) * dilation // 2

        half_kernel = kernel_size // 2
        if use_stride_conv:
            # Strided variant: stride equals the kernel size, no dilation.
            self.stride = kernel_size
            self.dilation = 1
            shift = half_kernel
        else:
            self.stride = 1
            self.dilation = dilation
            shift = half_kernel * dilation
        # The residual slice is moved by this offset for causal convolutions.
        self.causal_shift = shift if causal else 0

        temporal_conv = ConvModule(
            in_channels,
            mid_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        self.conv1 = nn.Sequential(temporal_conv)

        pointwise_conv = ConvModule(
            mid_channels,
            out_channels,
            kernel_size=1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        self.conv2 = nn.Sequential(pointwise_conv)

        # A 1x1 projection is only needed when the residual branch has to
        # change its number of channels.
        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input whose axis 2 is sliced for the residual
                branch (presumably laid out as (N, C, T) - confirm against
                the caller).

        Returns:
            torch.Tensor: Output of ``conv2``, plus the (optionally
            projected) residual slice of ``x`` when ``residual`` is True.
        """
        # Start index of the residual slice; also used for the sanity check
        # that the input is long enough along axis 2.
        if self.use_stride_conv:
            first = self.causal_shift + self.kernel_size // 2
            assert first < x.shape[2]
        else:
            first = self.pad + self.causal_shift
            assert 0 <= first < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if not self.residual:
            return out

        if self.use_stride_conv:
            # Take every ``kernel_size``-th frame starting at ``first``.
            res = x[:, :, first::self.kernel_size]
        else:
            last = x.shape[2] - self.pad + self.causal_shift
            res = x[:, :, first:last]

        if self.short_cut is not None:
            res = self.short_cut(res)

        return out + res
@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    The network consists of one expanding convolution followed by
    ``num_blocks`` :class:`BasicTemporalBlock` modules. The dilation handed
    to each block is the product of the kernel sizes of all preceding
    layers.

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Must equal ``len(kernel_sizes) - 1``. Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            the expanding convolution (first entry) and of each basic
            block (remaining entries). All entries must be odd.
            Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. A value that is not greater than 0
            disables dropout. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications).
            Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): if not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.
            Default: None.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        super().__init__()
        # Work on private copies so the shared default dicts stay untouched.
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)

        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        # One kernel size for the expanding conv plus one per block.
        assert num_blocks == len(kernel_sizes) - 1
        assert all(ks % 2 == 1 for ks in kernel_sizes), \
            'Only odd filter widths are supported.'

        stem_kernel = kernel_sizes[0]
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=stem_kernel,
            stride=stem_kernel if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # The dilation grows multiplicatively with every layer's kernel size.
        blocks = []
        dilation = stem_kernel
        for idx in range(1, num_blocks + 1):
            block_kernel = kernel_sizes[idx]
            blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=block_kernel,
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            dilation *= block_kernel
        self.tcn_blocks = nn.ModuleList(blocks)

        if self.max_norm is not None:
            # Register the norm-clipping hook on every convolution layer.
            clipper = WeightNormClipHook(self.max_norm)
            conv_layers = (
                layer for layer in self.modules()
                if isinstance(layer, nn.modules.conv._ConvNd))
            for layer in conv_layers:
                clipper.register(layer)

        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input fed to the expanding convolution.

        Returns:
            tuple[torch.Tensor]: The output of every temporal block, in
            block order.
        """
        x = self.expand_conv(x)
        if self.dropout is not None:
            x = self.dropout(x)

        feats = []
        for idx in range(self.num_blocks):
            block = self.tcn_blocks[idx]
            x = block(x)
            feats.append(x)

        return tuple(feats)

    def init_weights(self, pretrained=None):
        """Initialize the weights.

        Args:
            pretrained (str|None): Forwarded to the parent class'
                ``init_weights``. Only when it is None are the convolution
                layers Kaiming-initialized and the batch-norm layers
                initialized with the constant 1.
        """
        super().init_weights(pretrained)
        if pretrained is not None:
            return

        for m in self.modules():
            if isinstance(m, nn.modules.conv._ConvNd):
                kaiming_init(m, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, _BatchNorm):
                constant_init(m, 1)