File size: 3,109 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Optional

import torch
from torch.nn import Module, Sequential, Parameter

from tha3.module.module_factory import ModuleFactory
from tha3.nn.conv import create_conv1
from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha3.nn.normalization import NormalizationLayerFactory
from tha3.nn.separable_conv import create_separable_conv3
from tha3.nn.util import BlockArgs


class ResnetBlockSeparable(Module):
    """Residual block whose residual path uses separable 3x3 (or plain 1x1) convolutions.

    ``forward`` returns ``x + resnet_path(x)``, or ``x + scale * resnet_path(x)``
    when ``use_scale_parameter`` is enabled. Every convolution maps
    ``num_channels`` to ``num_channels``, so the residual path preserves the
    channel count and its output can be added back onto the input.
    """

    @staticmethod
    def create(num_channels: int,
               is1x1: bool = False,
               use_scale_parameters: bool = False,
               block_args: Optional[BlockArgs] = None):
        """Build a block from a ``BlockArgs`` settings bundle.

        Args:
            num_channels: Number of input (and output) channels.
            is1x1: If True, the residual path uses two 1x1 convolutions instead
                of two separable 3x3 convolutions with normalization.
            use_scale_parameters: Forwarded to the constructor's
                ``use_scale_parameter``.
            block_args: Initialization method, nonlinearity factory,
                normalization factory and spectral-norm flag. A default
                ``BlockArgs()`` is used when None.

        Returns:
            A new ``ResnetBlockSeparable``.
        """
        if block_args is None:
            block_args = BlockArgs()
        # Pass by keyword: this factory's parameter name (plural) differs from
        # the constructor's (singular), so a positional call would silently
        # misroute values if the constructor's parameter order ever changed.
        return ResnetBlockSeparable(
            num_channels=num_channels,
            is1x1=is1x1,
            initialization_method=block_args.initialization_method,
            nonlinearity_factory=block_args.nonlinearity_factory,
            normalization_layer_factory=block_args.normalization_layer_factory,
            use_spectral_norm=block_args.use_spectral_norm,
            use_scale_parameter=use_scale_parameters)

    def __init__(self,
                 num_channels: int,
                 is1x1: bool = False,
                 initialization_method: str = 'he',
                 nonlinearity_factory: Optional[ModuleFactory] = None,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 use_spectral_norm: bool = False,
                 use_scale_parameter: bool = False):
        """Construct the residual path.

        Args:
            num_channels: Number of input (and output) channels of every conv.
            is1x1: If True, the path is conv1x1 (bias) -> nonlinearity ->
                conv1x1 (bias). Otherwise it is separable conv3 (no bias) ->
                norm -> nonlinearity -> separable conv3 (no bias) -> norm.
            initialization_method: Weight-initialization scheme name passed to
                the conv constructors (default ``'he'``).
            nonlinearity_factory: Factory for the activation module. It is passed
                through ``resolve_nonlinearity_factory``, which presumably
                substitutes a default when None (confirm in tha3.nn).
            normalization_layer_factory: Factory for the normalization layers
                (only used when ``is1x1`` is False). It is passed through
                ``NormalizationLayerFactory.resolve_2d``, which presumably
                supplies a default when None.
            use_spectral_norm: Forwarded to every convolution constructor.
            use_scale_parameter: If True, the residual path's output is
                multiplied by a learnable scalar ``self.scale``.
        """
        super().__init__()
        self.use_scale_parameter = use_scale_parameter
        if self.use_scale_parameter:
            # Zero-initialized, so the block starts out as the identity map
            # (x + 0 * resnet_path(x)) and the residual's weight is learned.
            self.scale = Parameter(torch.zeros(1))
        nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)
        if is1x1:
            self.resnet_path = Sequential(
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm),
                nonlinearity_factory.create(),
                create_conv1(num_channels, num_channels, initialization_method,
                             bias=True,
                             use_spectral_norm=use_spectral_norm))
        else:
            # Resolve the (possibly None) factory once and use it for both
            # normalization layers below.
            norm_factory = NormalizationLayerFactory.resolve_2d(normalization_layer_factory)
            # Convs are bias-free here, presumably because the affine
            # normalization layer that follows each one provides its own shift.
            self.resnet_path = Sequential(
                create_separable_conv3(
                    num_channels, num_channels,
                    bias=False, initialization_method=initialization_method,
                    use_spectral_norm=use_spectral_norm),
                norm_factory.create(num_channels, affine=True),
                nonlinearity_factory.create(),
                create_separable_conv3(
                    num_channels, num_channels,
                    bias=False, initialization_method=initialization_method,
                    use_spectral_norm=use_spectral_norm),
                norm_factory.create(num_channels, affine=True))

    def forward(self, x):
        """Return ``x`` plus the (optionally scaled) residual path applied to ``x``.

        ``x`` is assumed to be a 2D feature map with ``num_channels`` channels
        (e.g. (N, C, H, W)) -- TODO confirm against the conv helpers.
        """
        residual = self.resnet_path(x)
        if self.use_scale_parameter:
            residual = self.scale * residual
        return x + residual