# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.core import WeightNormClipHook
from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class BasicTemporalBlock(nn.Module):
    """Basic block for VideoPose3D.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use an optimized TCN designed
            specifically for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
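
    Example:
        >>> # A minimal usage sketch: one block with the default
        >>> # kernel_size=3 and dilation=3. With no padding, the temporal
        >>> # dimension shrinks by (kernel_size - 1) * dilation = 6 frames.
        >>> # The import path assumes this file's usual location in mmpose.
        >>> from mmpose.models.backbones.tcn import BasicTemporalBlock
        >>> import torch
        >>> block = BasicTemporalBlock(in_channels=1024, out_channels=1024)
        >>> block.eval()
        >>> inputs = torch.rand(1, 1024, 27)
        >>> out = block(inputs)
        >>> print(tuple(out.shape))
        (1, 1024, 21)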
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d')):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0

        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        if self.use_stride_conv:
            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
        else:
            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
                self.pad + self.causal_shift <= x.shape[2]

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
            if self.use_stride_conv:
                res = x[:, :, self.causal_shift +
                        self.kernel_size // 2::self.kernel_size]
            else:
                res = x[:, :,
                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
                                                        self.causal_shift)]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out


@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__.

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernels, where
            ``kernel_sizes[0]`` is used by the expanding conv layer and the
            remaining ones by the basic temporal blocks.
            Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications).
            Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use a TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): If not None, the weight of convolution layers
            will be clipped to have a maximum norm of ``max_norm``.
            Default: None.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
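        >>> # A sketch of the strided variant (use_stride_conv=True): with
        >>> # the default kernel_sizes=(3, 3, 3) the receptive field is
        >>> # 3 * 3 * 3 = 27 frames, so a 27-frame input yields a single
        >>> # output frame at the last level.
        >>> self = TCN(in_channels=34, use_stride_conv=True)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 27)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 3)
        (1, 1024, 1)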
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.dropout = dropout
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        assert num_blocks == len(kernel_sizes) - 1
        for ks in kernel_sizes:
            assert ks % 2 == 1, 'Only odd filter widths are supported.'

        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=kernel_sizes[0],
            stride=kernel_sizes[0] if use_stride_conv else 1,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        dilation = kernel_sizes[0]
        self.tcn_blocks = nn.ModuleList()
        for i in range(1, num_blocks + 1):
            self.tcn_blocks.append(
                BasicTemporalBlock(
                    in_channels=stem_channels,
                    out_channels=stem_channels,
                    mid_channels=stem_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilation,
                    dropout=dropout,
                    causal=causal,
                    residual=residual,
                    use_stride_conv=use_stride_conv,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))
            dilation *= kernel_sizes[i]

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            weight_clip = WeightNormClipHook(self.max_norm)
            for module in self.modules():
                if isinstance(module, nn.modules.conv._ConvNd):
                    weight_clip.register(module)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function."""
        x = self.expand_conv(x)

        if self.dropout is not None:
            x = self.dropout(x)

        outs = []
        for i in range(self.num_blocks):
            x = self.tcn_blocks[i](x)
            outs.append(x)

        return tuple(outs)

    def init_weights(self, pretrained=None):
        """Initialize the weights."""
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.modules.conv._ConvNd):
                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
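

# Usage note: because ``TCN`` is registered in the BACKBONES registry above,
# it can also be built from a config dict. A minimal sketch, assuming
# mmpose's standard ``build_backbone`` helper (``in_channels`` is
# num_keypoints * num_features, e.g. 17 * 2 for 2D COCO keypoints):
#
#     from mmpose.models import build_backbone
#     backbone = build_backbone(
#         dict(type='TCN', in_channels=17 * 2, stem_channels=1024))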