File size: 7,464 Bytes
36d9761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm

from basicsr.utils.registry import ARCH_REGISTRY
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor


class SFTUpBlock(nn.Module):
    """Upsampling block modulated by a spatial feature transform (SFT).

    The input is first passed through a blur + convolution stage, then
    modulated element-wise as ``out * scale + shift`` where ``scale`` and
    ``shift`` are predicted from a second (conditioning) feature map, and
    finally upsampled by a factor of 2 and convolved once more.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size in convolutions. Default: 3.
        padding (int): Padding in convolutions. Default: 1.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
        super().__init__()

        def sn_conv3x3(c_in, c_out):
            # 3x3, stride 1, padding 1 convolution wrapped in spectral norm;
            # used by the scale/shift predictors below.
            return spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))

        # Pre-modulation stage.
        # The official codes use two LeakyReLU here, so 0.04 for equivalent
        pre_layers = [
            Blur(in_channel),
            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(negative_slope=0.04, inplace=True),
        ]
        self.conv1 = nn.Sequential(*pre_layers)

        # Post-modulation stage: x2 bilinear upsampling followed by a conv.
        up_layers = [
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.convup = nn.Sequential(*up_layers)

        # Predictors for the SFT scale and shift maps. Both consume a
        # conditioning feature with ``in_channel`` channels.
        self.scale_block = nn.Sequential(
            sn_conv3x3(in_channel, out_channel),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            sn_conv3x3(out_channel, out_channel),
        )
        # The official codes use sigmoid for shift block, do not know why
        self.shift_block = nn.Sequential(
            sn_conv3x3(in_channel, out_channel),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            sn_conv3x3(out_channel, out_channel),
            nn.Sigmoid(),
        )

    def forward(self, x, updated_feat):
        """Apply the SFT modulation to ``x`` and upsample the result.

        Args:
            x (Tensor): Feature map to be modulated and upsampled.
            updated_feat (Tensor): Conditioning feature map from which the
                scale and shift maps are predicted.

        Returns:
            Tensor: Output of the upsampling stage.
        """
        feat = self.conv1(x)
        # SFT: per-element affine modulation predicted from the condition
        gamma = self.scale_block(updated_feat)
        beta = self.shift_block(updated_feat)
        modulated = feat * gamma + beta
        # upsample
        return self.convup(modulated)


@ARCH_REGISTRY.register()
class DFDNet(nn.Module):
    """DFDNet: Deep Face Dictionary Network.

    It only processes faces with 512x512 size, one face per forward pass.

    Args:
        num_feat (int): Number of feature channels.
        dict_path (str): Path to the facial component dictionary.
    """

    def __init__(self, num_feat, dict_path):
        super().__init__()
        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
        # part_sizes: [80, 80, 50, 110]
        channel_sizes = [128, 256, 512, 512]
        self.feature_sizes = np.array([256, 128, 64, 32])
        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
        # True once the dictionary has been moved at least once; the
        # (device, dtype) it currently lives on is kept in _dict_target.
        self.flag_dict_device = False
        self._dict_target = None

        # dict, indexed in forward() as self.dict[str(feature_size)][part_name].
        # Load onto CPU so that a dictionary saved from a GPU can be opened on
        # any host; put_dict_to_device() moves it next to the input lazily.
        # NOTE(review): torch.load unpickles the file - only load dictionaries
        # from a trusted source.
        self.dict = torch.load(dict_path, map_location='cpu')

        # vgg face extractor
        self.vgg_extractor = VGGFeatureExtractor(
            layer_name_list=self.vgg_layers,
            vgg_type='vgg19',
            use_input_norm=True,
            range_norm=True,
            requires_grad=False)

        # attention block for fusing dictionary features and input features
        # (one block per part and per feature size, keyed '<part>_<size>')
        self.attn_blocks = nn.ModuleDict()
        for idx, feat_size in enumerate(self.feature_sizes):
            for name in self.parts:
                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])

        # multi scale dilation block
        # NOTE(review): it is applied to the 'conv5_4' features in forward(),
        # so num_feat * 8 presumably has to match that layer's channel count
        # (channel_sizes[-1] == 512) - confirm.
        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])

        # upsampling and reconstruction
        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
        self.upsample4 = nn.Sequential(
            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())

    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
        """Swap the features of one facial part with dictionary features.

        Args:
            vgg_feat (Tensor): VGG feature map the part is cropped from.
            updated_feat (Tensor): Feature map that receives the fused part
                features. It is modified in place.
            dict_feat (Tensor): Dictionary entries for this part. It is used
                as the weight of ``F.conv2d``, i.e. one entry per index of
                its first dimension.
            location (Tensor): Four integer coordinates of the part; indices
                1 and 3 bound dim 2 of the feature map, indices 0 and 2 bound
                dim 3.
            part_name (str): Name of the facial part, one of ``self.parts``.
            f_size (int): Feature size used to pick the attention block.

        Returns:
            Tensor: ``updated_feat`` with the part region replaced.
        """
        # get the original vgg features
        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
        # resize original vgg features
        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
        # use adaptive instance normalization to adjust color and illuminations
        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
        # get similarity scores: the kernel covers the whole resized crop, so
        # the convolution yields one score per dictionary entry
        similarity_score = F.conv2d(part_resize_feat, dict_feat)
        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
        # select the most similar features in the dict (after norm)
        select_idx = torch.argmax(similarity_score)
        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
        # attention (same key format as used when building attn_blocks)
        attn = self.attn_blocks[f'{part_name}_{f_size}'](swap_feat - part_feat)
        attn_feat = attn * swap_feat
        # update features
        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
        return updated_feat

    def put_dict_to_device(self, x):
        """Move the dictionary tensors to the device and dtype of ``x``.

        The dictionary is a plain attribute rather than a registered buffer,
        so ``module.to(...)`` does not move it. The move is repeated whenever
        the input's device or dtype differs from the one the dictionary was
        last moved to, and skipped otherwise.

        Args:
            x (Tensor): Reference tensor, normally the network input.
        """
        target = (x.device, x.dtype)
        # getattr: tolerate instances created before _dict_target existed
        if self.flag_dict_device and getattr(self, '_dict_target', None) == target:
            return
        for part_feats in self.dict.values():
            for part_name, feat in part_feats.items():
                # re-assigning an existing key does not resize the dict, so
                # this is safe while iterating
                part_feats[part_name] = feat.to(x)
        self._dict_target = target
        self.flag_dict_device = True

    def forward(self, x, part_locations):
        """
        Now only support testing with batch size = 1.

        Args:
            x (Tensor): Input faces with shape (b, c, 512, 512).
            part_locations (list[Tensor]): Part locations, one entry per part
                in the order of ``self.parts``; each entry is indexed by batch
                and yields four coordinates in the 512x512 input frame.

        Returns:
            Tensor: Output of the final reconstruction stage (``upsample4``).

        Raises:
            ValueError: If the batch size of ``x`` is not 1.
        """
        if x.size(0) != 1:
            raise ValueError(f'DFDNet only supports batch size = 1, but got {x.size(0)}.')

        self.put_dict_to_device(x)
        # extract vggface features
        vgg_features = self.vgg_extractor(x)
        # update vggface features using the dictionary for each part
        updated_vgg_features = []
        batch = 0  # index of the only sample in the batch
        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
            dict_features = self.dict[f'{f_size}']
            vgg_feat = vgg_features[vgg_layer]
            updated_feat = vgg_feat.clone()

            # swap features from dictionary
            for part_idx, part_name in enumerate(self.parts):
                # rescale the coordinates from the 512x512 input frame to the
                # resolution of the current feature map
                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
                                              f_size)

            updated_vgg_features.append(updated_feat)

        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
        # use updated vgg features to modulate the upsampled features with
        # SFT (Spatial Feature Transform) scaling and shifting manner.
        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
        out = self.upsample4(upsampled_feat)

        return out