Spanicin committed
Commit 5b1ae50 · verified · Parent(s): ccc8e7f

Upload 7 files
videoretalking/models/DNet.py ADDED
@@ -0,0 +1,118 @@
+ # TODO
+ import functools
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from utils import flow_util
+ from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
+
+ # DNet
+ class DNet(nn.Module):
+     def __init__(self):
+         super(DNet, self).__init__()
+         self.mapping_net = MappingNet()
+         self.warpping_net = WarpingNet()
+         self.editing_net = EditingNet()
+
+     def forward(self, input_image, driving_source, stage=None):
+         descriptor = self.mapping_net(driving_source)
+         output = self.warpping_net(input_image, descriptor)
+         if stage != 'warp':
+             output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor)
+         return output
+
+ class MappingNet(nn.Module):
+     def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
+         super(MappingNet, self).__init__()
+
+         self.layer = layer
+         nonlinearity = nn.LeakyReLU(0.1)
+
+         self.first = nn.Sequential(
+             torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+         for i in range(layer):
+             net = nn.Sequential(nonlinearity,
+                 torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+             setattr(self, 'encoder' + str(i), net)
+
+         self.pooling = nn.AdaptiveAvgPool1d(1)
+         self.output_nc = descriptor_nc
+
+     def forward(self, input_3dmm):
+         out = self.first(input_3dmm)
+         for i in range(self.layer):
+             model = getattr(self, 'encoder' + str(i))
+             # each dilated conv trims 3 samples per side, so crop the residual to match
+             out = model(out) + out[:, :, 3:-3]
+         out = self.pooling(out)
+         return out
+
+ class WarpingNet(nn.Module):
+     def __init__(
+         self,
+         image_nc=3,
+         descriptor_nc=256,
+         base_nc=32,
+         max_nc=256,
+         encoder_layer=5,
+         decoder_layer=3,
+         use_spect=False
+     ):
+         super(WarpingNet, self).__init__()
+
+         nonlinearity = nn.LeakyReLU(0.1)
+         norm_layer = functools.partial(LayerNorm2d, affine=True)
+         kwargs = {'nonlinearity': nonlinearity, 'use_spect': use_spect}
+
+         self.descriptor_nc = descriptor_nc
+         self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc,
+                                         max_nc, encoder_layer, decoder_layer, **kwargs)
+
+         self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc),
+                                       nonlinearity,
+                                       nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))
+
+         self.pool = nn.AdaptiveAvgPool2d(1)
+
+     def forward(self, input_image, descriptor):
+         final_output = {}
+         output = self.hourglass(input_image, descriptor)
+         final_output['flow_field'] = self.flow_out(output)
+
+         deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
+         final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
+         return final_output
+
+
+ class EditingNet(nn.Module):
+     def __init__(
+         self,
+         image_nc=3,
+         descriptor_nc=256,
+         layer=3,
+         base_nc=64,
+         max_nc=256,
+         num_res_blocks=2,
+         use_spect=False):
+         super(EditingNet, self).__init__()
+
+         nonlinearity = nn.LeakyReLU(0.1)
+         norm_layer = functools.partial(LayerNorm2d, affine=True)
+         kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
+         self.descriptor_nc = descriptor_nc
+
+         # encoder part
+         self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs)
+         self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+
+     def forward(self, input_image, warp_image, descriptor):
+         x = torch.cat([input_image, warp_image], 1)
+         x = self.encoder(x)
+         gen_image = self.decoder(x, descriptor)
+         return gen_image
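A minimal smoke test for DNet (a sketch, not part of the commit): it assumes the repository root is on PYTHONPATH, that utils.flow_util follows the PIRender convention of resizing the deformation grid to the source image before sampling, and that the driving input is a 73-channel, 27-frame 3DMM coefficient window; 27 frames satisfies MappingNet's receptive field (the kernel-7 conv and the three dilated convs each trim 6 samples, so at least 25 are needed).

import torch
from models.DNet import DNet

net = DNet().eval()
image = torch.randn(1, 3, 256, 256)   # source frame
coeffs = torch.randn(1, 73, 27)       # hypothetical 3DMM coefficient window
with torch.no_grad():
    warped = net(image, coeffs, stage='warp')   # warp-only branch
    full = net(image, coeffs)                   # warp + editing branch
print(warped['flow_field'].shape, warped['warp_image'].shape, full['fake_image'].shape)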
videoretalking/models/ENet.py ADDED
@@ -0,0 +1,139 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from models.base_blocks import ResBlock, StyleConv, ToRGB
+
+
+ class ENet(nn.Module):
+     def __init__(
+         self,
+         num_style_feat=512,
+         lnet=None,
+         concat=False
+     ):
+         super(ENet, self).__init__()
+
+         self.low_res = lnet
+         for param in self.low_res.parameters():
+             param.requires_grad = False
+
+         channel_multiplier, narrow = 2, 1
+         channels = {
+             '4': int(512 * narrow),
+             '8': int(512 * narrow),
+             '16': int(512 * narrow),
+             '32': int(512 * narrow),
+             '64': int(256 * channel_multiplier * narrow),
+             '128': int(128 * channel_multiplier * narrow),
+             '256': int(64 * channel_multiplier * narrow),
+             '512': int(32 * channel_multiplier * narrow),
+             '1024': int(16 * channel_multiplier * narrow)
+         }
+
+         self.log_size = 8
+         first_out_size = 128
+         self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)  # 256 -> 128
+
+         # downsample
+         in_channels = channels[f'{first_out_size}']
+         self.conv_body_down = nn.ModuleList()
+         for i in range(8, 2, -1):
+             out_channels = channels[f'{2**(i - 1)}']
+             self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
+             in_channels = out_channels
+
+         self.num_style_feat = num_style_feat
+         linear_out_channel = num_style_feat
+         self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
+         self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
+
+         self.style_convs = nn.ModuleList()
+         self.to_rgbs = nn.ModuleList()
+         self.noises = nn.Module()
+
+         self.concat = concat
+         if concat:
+             in_channels = 3 + 32  # channels['64']
+         else:
+             in_channels = 3
+
+         for i in range(7, 9):  # 128, 256
+             out_channels = channels[f'{2**i}']
+             self.style_convs.append(
+                 StyleConv(
+                     in_channels,
+                     out_channels,
+                     kernel_size=3,
+                     num_style_feat=num_style_feat,
+                     demodulate=True,
+                     sample_mode='upsample'))
+             self.style_convs.append(
+                 StyleConv(
+                     out_channels,
+                     out_channels,
+                     kernel_size=3,
+                     num_style_feat=num_style_feat,
+                     demodulate=True,
+                     sample_mode=None))
+             self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
+             in_channels = out_channels
+
+     def forward(self, audio_sequences, face_sequences, gt_sequences):
+         B = audio_sequences.size(0)
+         input_dim_size = len(face_sequences.size())
+         inp, ref = torch.split(face_sequences, 3, dim=1)
+
+         if input_dim_size > 4:
+             audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+             inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
+             ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
+             gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)
+
+         # get the global style
+         feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256, 256), mode='bilinear')), negative_slope=0.2)
+         for i in range(self.log_size - 2):
+             feat = self.conv_body_down[i](feat)
+         feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
+
+         # style code
+         style_code = self.final_linear(feat.reshape(feat.size(0), -1))
+         style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)
+
+         LNet_input = torch.cat([inp, gt_sequences], dim=1)
+         LNet_input = F.interpolate(LNet_input, size=(96, 96), mode='bilinear')
+
+         if self.concat:
+             low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
+             # detach() is not in-place; keep the detached tensors so no gradient reaches the frozen LNet
+             low_res_img = low_res_img.detach()
+             low_res_feat = low_res_feat.detach()
+             out = torch.cat([low_res_img, low_res_feat], dim=1)
+         else:
+             low_res_img = self.low_res(audio_sequences, LNet_input)
+             low_res_img = low_res_img.detach()
+             # 96 x 96
+             out = low_res_img
+
+         p2d = (2, 2, 2, 2)
+         out = F.pad(out, p2d, "reflect")
+         skip = out
+
+         for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
+             out = conv1(out, style_code)  # 96, 192, 384
+             out = conv2(out, style_code)
+             skip = to_rgb(out, style_code, skip)
+         _outputs = skip
+
+         # remove padding
+         _outputs = _outputs[:, :, 8:-8, 8:-8]
+
+         if input_dim_size > 4:
+             _outputs = torch.split(_outputs, B, dim=0)
+             outputs = torch.stack(_outputs, dim=2)
+             low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
+             low_res_img = torch.split(low_res_img, B, dim=0)
+             low_res_img = torch.stack(low_res_img, dim=2)
+         else:
+             outputs = _outputs
+         return outputs, low_res_img
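The style path above can be checked with plain arithmetic (a self-contained sketch using the same narrow=1, channel_multiplier=2 constants as __init__): the six ResBlocks for i = 8..3 halve a 256x256 reference down to 4x4, so the feature flattened into final_linear has channels['4'] * 4 * 4 = 8192 elements, mapped to a single 512-dim style code.

channel_multiplier, narrow = 2, 1
channels = {'4': 512 * narrow, '8': 512 * narrow, '16': 512 * narrow, '32': 512 * narrow,
            '64': 256 * channel_multiplier * narrow, '128': 128 * channel_multiplier * narrow}
size, width = 256, channels['128']          # after conv_body_first (1x1 conv, 3 -> 256 channels)
for i in range(8, 2, -1):                   # the six downsampling ResBlocks
    size, width = size // 2, channels[f'{2**(i - 1)}']
assert (size, width) == (4, channels['4'])
print('final_linear input features:', channels['4'] * 4 * 4)   # 8192 -> 512 style code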
videoretalking/models/LNet.py ADDED
@@ -0,0 +1,139 @@
+ import functools
+ import torch
+ import torch.nn as nn
+
+ from models.transformer import RETURNX, Transformer
+ from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
+     FFCADAINResBlocks, Jump, FinalBlock2d
+
+
+ class Visual_Encoder(nn.Module):
+     def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(Visual_Encoder, self).__init__()
+         self.layers = layers
+         self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+         self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+         for i in range(layers):
+             in_channels = min(ngf*(2**i), img_f)
+             out_channels = min(ngf*(2**(i+1)), img_f)
+             model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+             model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+             if i < 2:
+                 ca_layer = RETURNX()
+             else:
+                 ca_layer = Transformer(2**(i+1) * ngf, 2, 4, ngf, ngf*4)
+             setattr(self, 'ca' + str(i), ca_layer)
+             setattr(self, 'ref_down' + str(i), model_ref)
+             setattr(self, 'inp_down' + str(i), model_inp)
+         self.output_nc = out_channels * 2
+
+     def forward(self, maskGT, ref):
+         x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref)
+         out = [x_maskGT]
+         for i in range(self.layers):
+             model_ref = getattr(self, 'ref_down' + str(i))
+             model_inp = getattr(self, 'inp_down' + str(i))
+             ca_layer = getattr(self, 'ca' + str(i))
+             x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref)
+             x_maskGT = ca_layer(x_maskGT, x_ref)
+             if i < self.layers - 1:
+                 out.append(x_maskGT)
+             else:
+                 out.append(torch.cat([x_maskGT, x_ref], dim=1))  # concat ref features!
+         return out
+
+
+ class Decoder(nn.Module):
+     def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(Decoder, self).__init__()
+         self.layers = layers
+         for i in range(layers)[::-1]:
+             if i == layers-1:
+                 # the deepest encoder feature carries the concatenated reference channels
+                 in_channels = ngf*(2**(i+1)) * 2
+             else:
+                 in_channels = min(ngf*(2**(i+1)), img_f)
+             out_channels = min(ngf*(2**i), img_f)
+             up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+             res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+             jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+
+             setattr(self, 'up' + str(i), up)
+             setattr(self, 'res' + str(i), res)
+             setattr(self, 'jump' + str(i), jump)
+
+         self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
+         self.output_nc = out_channels
+
+     def forward(self, x, z):
+         out = x.pop()
+         for i in range(self.layers)[::-1]:
+             res_model = getattr(self, 'res' + str(i))
+             up_model = getattr(self, 'up' + str(i))
+             jump_model = getattr(self, 'jump' + str(i))
+             out = res_model(out, z)
+             out = up_model(out)
+             out = jump_model(x.pop()) + out
+         out_image = self.final(out)
+         return out_image
+
+
+ class LNet(nn.Module):
+     def __init__(
+         self,
+         image_nc=3,
+         descriptor_nc=512,
+         layer=3,
+         base_nc=64,
+         max_nc=512,
+         num_res_blocks=9,
+         use_spect=True,
+         encoder=Visual_Encoder,
+         decoder=Decoder
+     ):
+         super(LNet, self).__init__()
+
+         nonlinearity = nn.LeakyReLU(0.1)
+         norm_layer = functools.partial(LayerNorm2d, affine=True)
+         kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
+         self.descriptor_nc = descriptor_nc
+
+         self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
+         self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
+         self.audio_encoder = nn.Sequential(
+             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+             Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
+         )
+
+     def forward(self, audio_sequences, face_sequences):
+         B = audio_sequences.size(0)
+         input_dim_size = len(face_sequences.size())
+         if input_dim_size > 4:
+             # fold the time dimension into the batch
+             audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+             face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+         cropped, ref = torch.split(face_sequences, 3, dim=1)
+
+         vis_feat = self.encoder(cropped, ref)
+         audio_feat = self.audio_encoder(audio_sequences)
+         _outputs = self.decoder(vis_feat, audio_feat)
+
+         if input_dim_size > 4:
+             # unfold back to (B, C, T, H, W)
+             _outputs = torch.split(_outputs, B, dim=0)
+             outputs = torch.stack(_outputs, dim=2)
+         else:
+             outputs = _outputs
+         return outputs
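A hedged smoke test for LNet (it assumes basicsr and einops are installed, since models.base_blocks imports from basicsr, and that audio arrives as wav2lip-style mel windows of shape (B, 1, 80, 16); with that shape the audio encoder's strided convs reduce 80x16 down to 1x1, giving the (B, 512, 1, 1) descriptor the ADAIN blocks expect).

import torch
from models.LNet import LNet

net = LNet().eval()
mel = torch.randn(1, 1, 80, 16)     # hypothetical mel window
faces = torch.randn(1, 6, 96, 96)   # masked input + reference, concatenated on channels
with torch.no_grad():
    out = net(mel, faces)           # (1, 3, 96, 96), sigmoid-range image
print(out.shape)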
videoretalking/models/__init__.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ from models.DNet import DNet
+ from models.LNet import LNet
+ from models.ENet import ENet
+
+
+ def _load(checkpoint_path):
+     map_location = None if torch.cuda.is_available() else torch.device('cpu')
+     checkpoint = torch.load(checkpoint_path, map_location=map_location)
+     return checkpoint
+
+
+ def load_checkpoint(path, model):
+     print("Load checkpoint from: {}".format(path))
+     checkpoint = _load(path)
+     s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
+     new_s = {}
+     for k, v in s.items():
+         # skip the frozen low-res generator weights; ENet receives LNet separately
+         if 'low_res' in k:
+             continue
+         new_s[k.replace('module.', '')] = v
+     model.load_state_dict(new_s, strict=False)
+     return model
+
+
+ def load_network(LNet_path, ENet_path):
+     L_net = LNet()
+     L_net = load_checkpoint(LNet_path, L_net)
+     E_net = ENet(lnet=L_net)
+     model = load_checkpoint(ENet_path, E_net)
+     return model.eval()
+
+
+ def load_DNet(DNet_path):
+     D_Net = DNet()
+     print("Load checkpoint from: {}".format(DNet_path))
+     checkpoint = torch.load(DNet_path, map_location=lambda storage, loc: storage)
+     D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
+     return D_Net.eval()
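A usage sketch for the loaders above; the checkpoint filenames are assumptions in the style of the VideoReTalking release, not values recorded in this commit.

import torch
from models import load_network, load_DNet

enet = load_network('checkpoints/LNet.pth', 'checkpoints/ENet.pth')   # hypothetical paths
dnet = load_DNet('checkpoints/DNet.pt')                               # hypothetical path
device = 'cuda' if torch.cuda.is_available() else 'cpu'
enet, dnet = enet.to(device), dnet.to(device)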
videoretalking/models/base_blocks.py ADDED
@@ -0,0 +1,554 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.batchnorm import BatchNorm2d
+ from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
+
+ from models.ffc import FFC
+ from basicsr.archs.arch_util import default_init_weights
+
+
+ class Conv2d(nn.Module):
+     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(cin, cout, kernel_size, stride, padding),
+             nn.BatchNorm2d(cout)
+         )
+         self.act = nn.ReLU()
+         self.residual = residual
+
+     def forward(self, x):
+         out = self.conv_block(x)
+         if self.residual:
+             out += x
+         return self.act(out)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, mode='down'):
+         super(ResBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
+         self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
+         self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+         if mode == 'down':
+             self.scale_factor = 0.5
+         elif mode == 'up':
+             self.scale_factor = 2
+
+     def forward(self, x):
+         out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
+         # upsample/downsample
+         out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
+         out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
+         # skip
+         x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
+         skip = self.skip(x)
+         out = out + skip
+         return out
+
+
+ class LayerNorm2d(nn.Module):
+     def __init__(self, n_out, affine=True):
+         super(LayerNorm2d, self).__init__()
+         self.n_out = n_out
+         self.affine = affine
+
+         if self.affine:
+             self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
+             self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))
+
+     def forward(self, x):
+         normalized_shape = x.size()[1:]
+         if self.affine:
+             return F.layer_norm(x, normalized_shape,
+                                 self.weight.expand(normalized_shape),
+                                 self.bias.expand(normalized_shape))
+         else:
+             return F.layer_norm(x, normalized_shape)
+
+
+ def spectral_norm(module, use_spect=True):
+     if use_spect:
+         return SpectralNorm(module)
+     else:
+         return module
+
+
+ class FirstBlock2d(nn.Module):
+     def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FirstBlock2d, self).__init__()
+         kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+         conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+
+         if norm_layer is None:
+             self.model = nn.Sequential(conv, nonlinearity)
+         else:
+             self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
+
+
+ class DownBlock2d(nn.Module):
+     def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(DownBlock2d, self).__init__()
+         kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+         pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+         if norm_layer is None:
+             self.model = nn.Sequential(conv, nonlinearity, pool)
+         else:
+             self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool)
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
+
+
+ class UpBlock2d(nn.Module):
+     def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(UpBlock2d, self).__init__()
+         kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+         if norm_layer is None:
+             self.model = nn.Sequential(conv, nonlinearity)
+         else:
+             self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity)
+
+     def forward(self, x):
+         out = self.model(F.interpolate(x, scale_factor=2))
+         return out
+
+
+ class ADAIN(nn.Module):
+     def __init__(self, norm_nc, feature_nc):
+         super().__init__()
+
+         self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+
+         nhidden = 128
+         use_bias = True
+
+         self.mlp_shared = nn.Sequential(
+             nn.Linear(feature_nc, nhidden, bias=use_bias),
+             nn.ReLU()
+         )
+         self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias)
+         self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias)
+
+     def forward(self, x, feature):
+         # Part 1. generate parameter-free normalized activations
+         normalized = self.param_free_norm(x)
+
+         # Part 2. produce scaling and bias conditioned on feature
+         feature = feature.view(feature.size(0), -1)
+         actv = self.mlp_shared(feature)
+         gamma = self.mlp_gamma(actv)
+         beta = self.mlp_beta(actv)
+
+         # apply scale and bias
+         gamma = gamma.view(*gamma.size()[:2], 1, 1)
+         beta = beta.view(*beta.size()[:2], 1, 1)
+         out = normalized * (1 + gamma) + beta
+         return out
+
+
+ class FineADAINResBlock2d(nn.Module):
+     """A residual block with ADAIN conditioning."""
+     def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FineADAINResBlock2d, self).__init__()
+         kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+         self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+         self.norm1 = ADAIN(input_nc, feature_nc)
+         self.norm2 = ADAIN(input_nc, feature_nc)
+         self.actvn = nonlinearity
+
+     def forward(self, x, z):
+         dx = self.actvn(self.norm1(self.conv1(x), z))
+         # note: the second branch is computed from x, not dx
+         dx = self.norm2(self.conv2(x), z)
+         out = dx + x
+         return out
+
+
+ class FineADAINResBlocks(nn.Module):
+     def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FineADAINResBlocks, self).__init__()
+         self.num_block = num_block
+         for i in range(num_block):
+             model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+             setattr(self, 'res' + str(i), model)
+
+     def forward(self, x, z):
+         for i in range(self.num_block):
+             model = getattr(self, 'res' + str(i))
+             x = model(x, z)
+         return x
+
+
+ class ADAINEncoderBlock(nn.Module):
+     def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(ADAINEncoderBlock, self).__init__()
+         kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
+         kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+         self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect)
+         self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect)
+
+         self.norm_0 = ADAIN(input_nc, feature_nc)
+         self.norm_1 = ADAIN(output_nc, feature_nc)
+         self.actvn = nonlinearity
+
+     def forward(self, x, z):
+         x = self.conv_0(self.actvn(self.norm_0(x, z)))
+         x = self.conv_1(self.actvn(self.norm_1(x, z)))
+         return x
+
+
+ class ADAINDecoderBlock(nn.Module):
+     def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(ADAINDecoderBlock, self).__init__()
+         # Attributes
+         self.actvn = nonlinearity
+         hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc
+
+         kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         if use_transpose:
+             kwargs_up = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'output_padding': 1}
+         else:
+             kwargs_up = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+
+         # create conv layers
+         self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect)
+         if use_transpose:
+             self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect)
+             self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect)
+         else:
+             self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect),
+                                         nn.Upsample(scale_factor=2))
+             self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect),
+                                         nn.Upsample(scale_factor=2))
+         # define normalization layers
+         self.norm_0 = ADAIN(input_nc, feature_nc)
+         self.norm_1 = ADAIN(hidden_nc, feature_nc)
+         self.norm_s = ADAIN(input_nc, feature_nc)
+
+     def forward(self, x, z):
+         x_s = self.shortcut(x, z)
+         dx = self.conv_0(self.actvn(self.norm_0(x, z)))
+         dx = self.conv_1(self.actvn(self.norm_1(dx, z)))
+         out = x_s + dx
+         return out
+
+     def shortcut(self, x, z):
+         x_s = self.conv_s(self.actvn(self.norm_s(x, z)))
+         return x_s
+
+
+ class FineEncoder(nn.Module):
+     """docstring for Encoder"""
+     def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FineEncoder, self).__init__()
+         self.layers = layers
+         self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
+         for i in range(layers):
+             in_channels = min(ngf*(2**i), img_f)
+             out_channels = min(ngf*(2**(i+1)), img_f)
+             model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+             setattr(self, 'down' + str(i), model)
+         self.output_nc = out_channels
+
+     def forward(self, x):
+         x = self.first(x)
+         out = [x]
+         for i in range(self.layers):
+             model = getattr(self, 'down' + str(i))
+             x = model(x)
+             out.append(x)
+         return out
+
+
+ class FineDecoder(nn.Module):
+     """docstring for FineDecoder"""
+     def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FineDecoder, self).__init__()
+         self.layers = layers
+         for i in range(layers)[::-1]:
+             in_channels = min(ngf*(2**(i+1)), img_f)
+             out_channels = min(ngf*(2**i), img_f)
+             up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
+             res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
+             jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
+             setattr(self, 'up' + str(i), up)
+             setattr(self, 'res' + str(i), res)
+             setattr(self, 'jump' + str(i), jump)
+         self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh')
+         self.output_nc = out_channels
+
+     def forward(self, x, z):
+         out = x.pop()
+         for i in range(self.layers)[::-1]:
+             res_model = getattr(self, 'res' + str(i))
+             up_model = getattr(self, 'up' + str(i))
+             jump_model = getattr(self, 'jump' + str(i))
+             out = res_model(out, z)
+             out = up_model(out)
+             out = jump_model(x.pop()) + out
+         out_image = self.final(out)
+         return out_image
+
+
+ class ADAINEncoder(nn.Module):
+     def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(ADAINEncoder, self).__init__()
+         self.layers = layers
+         self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3)
+         for i in range(layers):
+             in_channels = min(ngf * (2**i), img_f)
+             out_channels = min(ngf * (2**(i+1)), img_f)
+             model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect)
+             setattr(self, 'encoder' + str(i), model)
+         self.output_nc = out_channels
+
+     def forward(self, x, z):
+         out = self.input_layer(x)
+         out_list = [out]
+         for i in range(self.layers):
+             model = getattr(self, 'encoder' + str(i))
+             out = model(out, z)
+             out_list.append(out)
+         return out_list
+
+
+ class ADAINDecoder(nn.Module):
+     """docstring for ADAINDecoder"""
+     def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True,
+                  nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(ADAINDecoder, self).__init__()
+         self.encoder_layers = encoder_layers
+         self.decoder_layers = decoder_layers
+         self.skip_connect = skip_connect
+         use_transpose = True
+         for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]:
+             in_channels = min(ngf * (2**(i+1)), img_f)
+             in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels
+             out_channels = min(ngf * (2**i), img_f)
+             model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect)
+             setattr(self, 'decoder' + str(i), model)
+         self.output_nc = out_channels*2 if self.skip_connect else out_channels
+
+     def forward(self, x, z):
+         out = x.pop() if self.skip_connect else x
+         for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]:
+             model = getattr(self, 'decoder' + str(i))
+             out = model(out, z)
+             out = torch.cat([out, x.pop()], 1) if self.skip_connect else out
+         return out
+
+
+ class ADAINHourglass(nn.Module):
+     def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect):
+         super(ADAINHourglass, self).__init__()
+         self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect)
+         self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect)
+         self.output_nc = self.decoder.output_nc
+
+     def forward(self, x, z):
+         return self.decoder(self.encoder(x, z), z)
+
+
+ class FineADAINLama(nn.Module):
+     def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FineADAINLama, self).__init__()
+         kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         self.actvn = nonlinearity
+         ratio_gin = 0.75
+         ratio_gout = 0.75
+         self.ffc = FFC(input_nc, input_nc, 3,
+                        ratio_gin, ratio_gout, 1, 1, 1,
+                        1, False, False, padding_type='reflect')
+         global_channels = int(input_nc * ratio_gout)
+         self.bn_l = ADAIN(input_nc - global_channels, feature_nc)
+         self.bn_g = ADAIN(global_channels, feature_nc)
+
+     def forward(self, x, z):
+         x_l, x_g = self.ffc(x)
+         x_l = self.actvn(self.bn_l(x_l, z))
+         x_g = self.actvn(self.bn_g(x_g, z))
+         return x_l, x_g
+
+
+ class FFCResnetBlock(nn.Module):
+     def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1,
+                  spatial_transform_kwargs=None, inline=False, **conv_kwargs):
+         super().__init__()
+         self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs)
+         self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs)
+         self.inline = True
+
+     def forward(self, x, z):
+         if self.inline:
+             x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:]
+         else:
+             x_l, x_g = x if type(x) is tuple else (x, 0)
+
+         id_l, id_g = x_l, x_g
+         x_l, x_g = self.conv1((x_l, x_g), z)
+         x_l, x_g = self.conv2((x_l, x_g), z)
+
+         x_l, x_g = id_l + x_l, id_g + x_g
+         out = x_l, x_g
+         if self.inline:
+             out = torch.cat(out, dim=1)
+         return out
+
+
+ class FFCADAINResBlocks(nn.Module):
+     def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(FFCADAINResBlocks, self).__init__()
+         self.num_block = num_block
+         for i in range(num_block):
+             model = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect)
+             setattr(self, 'res' + str(i), model)
+
+     def forward(self, x, z):
+         for i in range(self.num_block):
+             model = getattr(self, 'res' + str(i))
+             x = model(x, z)
+         return x
+
+
+ class Jump(nn.Module):
+     def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
+         super(Jump, self).__init__()
+         kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
+         conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect)
+         if norm_layer is None:
+             self.model = nn.Sequential(conv, nonlinearity)
+         else:
+             self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity)
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
+
+
+ class FinalBlock2d(nn.Module):
+     def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'):
+         super(FinalBlock2d, self).__init__()
+         kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3}
+         conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
+         if tanh_or_sigmoid == 'sigmoid':
+             out_nonlinearity = nn.Sigmoid()
+         else:
+             out_nonlinearity = nn.Tanh()
+         self.model = nn.Sequential(conv, out_nonlinearity)
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
+
+
+ class ModulatedConv2d(nn.Module):
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  num_style_feat,
+                  demodulate=True,
+                  sample_mode=None,
+                  eps=1e-8):
+         super(ModulatedConv2d, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.demodulate = demodulate
+         self.sample_mode = sample_mode
+         self.eps = eps
+
+         # modulation inside each modulated conv
+         self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
+         # initialization
+         default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
+
+         self.weight = nn.Parameter(
+             torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
+             math.sqrt(in_channels * kernel_size**2))
+         self.padding = kernel_size // 2
+
+     def forward(self, x, style):
+         b, c, h, w = x.shape
+         style = self.modulation(style).view(b, 1, c, 1, 1)
+         weight = self.weight * style
+
+         if self.demodulate:
+             demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+             weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+         weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+         # upsample or downsample if necessary
+         if self.sample_mode == 'upsample':
+             x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+         elif self.sample_mode == 'downsample':
+             x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+         b, c, h, w = x.shape
+         x = x.view(1, b * c, h, w)
+         # fold the batch into the group dimension so each sample uses its own kernel
+         out = F.conv2d(x, weight, padding=self.padding, groups=b)
+         out = out.view(b, self.out_channels, *out.shape[2:4])
+         return out
+
+     def __repr__(self):
+         return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
+                 f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+ class StyleConv(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
+         super(StyleConv, self).__init__()
+         self.modulated_conv = ModulatedConv2d(
+             in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
+         self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+         self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
+         self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x, style, noise=None):
+         # modulate
+         out = self.modulated_conv(x, style) * 2**0.5  # for conversion
+         # noise injection
+         if noise is None:
+             b, _, h, w = out.shape
+             noise = out.new_empty(b, 1, h, w).normal_()
+         out = out + self.weight * noise
+         # add bias
+         out = out + self.bias
+         # activation
+         out = self.activate(out)
+         return out
+
+
+ class ToRGB(nn.Module):
+     def __init__(self, in_channels, num_style_feat, upsample=True):
+         super(ToRGB, self).__init__()
+         self.upsample = upsample
+         self.modulated_conv = ModulatedConv2d(
+             in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+         self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+     def forward(self, x, style, skip=None):
+         out = self.modulated_conv(x, style)
+         out = out + self.bias
+         if skip is not None:
+             if self.upsample:
+                 skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
+             out = out + skip
+         return out
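ModulatedConv2d above relies on folding the batch into the group dimension so one grouped conv applies a different modulated kernel to every sample. A self-contained sanity check of that trick (pure torch, no repo imports):

import torch
import torch.nn.functional as F

b, cin, cout, k, h = 2, 4, 6, 3, 8
x = torch.randn(b, cin, h, h)
weight = torch.randn(b, cout, cin, k, k)   # one (modulated) kernel per sample

# batched path: a single grouped conv over a (1, b*cin, h, h) view
out = F.conv2d(x.view(1, b * cin, h, h), weight.view(b * cout, cin, k, k),
               padding=k // 2, groups=b).view(b, cout, h, h)

# reference path: per-sample convolutions
ref = torch.stack([F.conv2d(x[i:i + 1], weight[i], padding=k // 2)[0] for i in range(b)])
assert torch.allclose(out, ref, atol=1e-5)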
videoretalking/models/ffc.py ADDED
@@ -0,0 +1,233 @@
+ # Fast Fourier Convolution, NeurIPS 2020
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ # from models.modules.squeeze_excitation import SELayer
+ import torch.fft
+
+
+ class SELayer(nn.Module):
+     def __init__(self, channel, reduction=16):
+         super(SELayer, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         b, c, _, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1, 1)
+         res = x * y.expand_as(x)
+         return res
+
+
+ class FFCSE_block(nn.Module):
+     def __init__(self, channels, ratio_g):
+         super(FFCSE_block, self).__init__()
+         in_cg = int(channels * ratio_g)
+         in_cl = channels - in_cg
+         r = 16
+
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.conv1 = nn.Conv2d(channels, channels // r,
+                                kernel_size=1, bias=True)
+         self.relu1 = nn.ReLU(inplace=True)
+         self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
+             channels // r, in_cl, kernel_size=1, bias=True)
+         self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
+             channels // r, in_cg, kernel_size=1, bias=True)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = x if type(x) is tuple else (x, 0)
+         id_l, id_g = x
+
+         x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1)
+         x = self.avgpool(x)
+         x = self.relu1(self.conv1(x))
+
+         x_l = 0 if self.conv_a2l is None else id_l * \
+             self.sigmoid(self.conv_a2l(x))
+         x_g = 0 if self.conv_a2g is None else id_g * \
+             self.sigmoid(self.conv_a2g(x))
+         return x_l, x_g
+
+
+ class FourierUnit(nn.Module):
+     def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
+                  spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
+         # bn_layer not used
+         super(FourierUnit, self).__init__()
+         self.groups = groups
+
+         self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
+                                           out_channels=out_channels * 2,
+                                           kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
+         self.bn = torch.nn.BatchNorm2d(out_channels * 2)
+         self.relu = torch.nn.ReLU(inplace=True)
+
+         # squeeze and excitation block
+         self.use_se = use_se
+         if use_se:
+             if se_kwargs is None:
+                 se_kwargs = {}
+             self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
+
+         self.spatial_scale_factor = spatial_scale_factor
+         self.spatial_scale_mode = spatial_scale_mode
+         self.spectral_pos_encoding = spectral_pos_encoding
+         self.ffc3d = ffc3d
+         self.fft_norm = fft_norm
+
+     def forward(self, x):
+         batch = x.shape[0]
+
+         if self.spatial_scale_factor is not None:
+             orig_size = x.shape[-2:]
+             x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
+
+         r_size = x.size()
+         # (batch, c, h, w/2+1, 2)
+         fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
+         ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
+         ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
+         ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
+         ffted = ffted.view((batch, -1,) + ffted.size()[3:])
+
+         if self.spectral_pos_encoding:
+             height, width = ffted.shape[-2:]
+             coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
+             coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
+             ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
+
+         if self.use_se:
+             ffted = self.se(ffted)
+
+         ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
+         ffted = self.relu(self.bn(ffted))
+
+         ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
+             0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
+         ffted = torch.complex(ffted[..., 0], ffted[..., 1])
+
+         ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
+         output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
+
+         if self.spatial_scale_factor is not None:
+             output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
+
+         return output
+
+
+ class SpectralTransform(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs):
+         # bn_layer not used
+         super(SpectralTransform, self).__init__()
+         self.enable_lfu = enable_lfu
+         if stride == 2:
+             self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
+         else:
+             self.downsample = nn.Identity()
+
+         self.stride = stride
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
+             nn.BatchNorm2d(out_channels // 2),
+             nn.ReLU(inplace=True)
+         )
+         self.fu = FourierUnit(
+             out_channels // 2, out_channels // 2, groups, **fu_kwargs)
+         if self.enable_lfu:
+             self.lfu = FourierUnit(
+                 out_channels // 2, out_channels // 2, groups)
+         self.conv2 = torch.nn.Conv2d(
+             out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
+
+     def forward(self, x):
+         x = self.downsample(x)
+         x = self.conv1(x)
+         output = self.fu(x)
+
+         if self.enable_lfu:
+             n, c, h, w = x.shape
+             split_no = 2
+             split_s = h // split_no
+             xs = torch.cat(torch.split(
+                 x[:, :c // 4], split_s, dim=-2), dim=1).contiguous()
+             xs = torch.cat(torch.split(xs, split_s, dim=-1),
+                            dim=1).contiguous()
+             xs = self.lfu(xs)
+             xs = xs.repeat(1, 1, split_no, split_no).contiguous()
+         else:
+             xs = 0
+
+         output = self.conv2(x + output + xs)
+         return output
+
+
+ class FFC(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size,
+                  ratio_gin, ratio_gout, stride=1, padding=0,
+                  dilation=1, groups=1, bias=False, enable_lfu=True,
+                  padding_type='reflect', gated=False, **spectral_kwargs):
+         super(FFC, self).__init__()
+
+         assert stride == 1 or stride == 2, "Stride should be 1 or 2."
+         self.stride = stride
+
+         in_cg = int(in_channels * ratio_gin)
+         in_cl = in_channels - in_cg
+         out_cg = int(out_channels * ratio_gout)
+         out_cl = out_channels - out_cg
+
+         self.ratio_gin = ratio_gin
+         self.ratio_gout = ratio_gout
+         self.global_in_num = in_cg
+
+         module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
+         self.convl2l = module(in_cl, out_cl, kernel_size,
+                               stride, padding, dilation, groups, bias, padding_mode=padding_type)
+         module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
+         self.convl2g = module(in_cl, out_cg, kernel_size,
+                               stride, padding, dilation, groups, bias, padding_mode=padding_type)
+         module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
+         self.convg2l = module(in_cg, out_cl, kernel_size,
+                               stride, padding, dilation, groups, bias, padding_mode=padding_type)
+         module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
+         self.convg2g = module(
+             in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
+
+         self.gated = gated
+         module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
+         self.gate = module(in_channels, 2, 1)
+
+     def forward(self, x):
+         x_l, x_g = x if type(x) is tuple else (x, 0)
+         out_xl, out_xg = 0, 0
+
+         if self.gated:
+             total_input_parts = [x_l]
+             if torch.is_tensor(x_g):
+                 total_input_parts.append(x_g)
+             total_input = torch.cat(total_input_parts, dim=1)
+
+             gates = torch.sigmoid(self.gate(total_input))
+             g2l_gate, l2g_gate = gates.chunk(2, dim=1)
+         else:
+             g2l_gate, l2g_gate = 1, 1
+
+         if self.ratio_gout != 1:
+             out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
+         if self.ratio_gout != 0:
+             out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
+
+         return out_xl, out_xg
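The spectral path in FourierUnit can be summarized with a standalone round trip (pure torch, no repo imports): rfftn to the half-spectrum, real and imaginary parts stacked into channels where an ordinary 1x1 conv (omitted here) mixes them with a global receptive field, then irfftn back to the original spatial size.

import torch

x = torch.randn(2, 16, 32, 32)
ffted = torch.fft.rfftn(x, dim=(-2, -1), norm='ortho')        # (2, 16, 32, 17), complex
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)         # (2, 16, 32, 17, 2)
ffted = ffted.permute(0, 1, 4, 2, 3).reshape(2, 32, 32, 17)   # real/imag folded into channels
# ... FourierUnit applies its 1x1 conv + BN + ReLU at this point ...
ffted = ffted.reshape(2, 16, 2, 32, 17).permute(0, 1, 3, 4, 2)
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
out = torch.fft.irfftn(ffted, s=(32, 32), dim=(-2, -1), norm='ortho')
assert torch.allclose(out, x, atol=1e-5)                      # exact inverse without the conv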
videoretalking/models/transformer.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ from einops import rearrange
+
+
+ class GELU(nn.Module):
+     def __init__(self):
+         super(GELU, self).__init__()
+
+     def forward(self, x):
+         # tanh approximation of GELU; torch.tanh replaces the deprecated F.tanh
+         return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+ # helpers
+
+ def pair(t):
+     return t if isinstance(t, tuple) else (t, t)
+
+
+ # classes
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+
+     def forward(self, x, **kwargs):
+         return self.fn(self.norm(x), **kwargs)
+
+
+ class DualPreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.normx = nn.LayerNorm(dim)
+         self.normy = nn.LayerNorm(dim)
+         self.fn = fn
+
+     def forward(self, x, y, **kwargs):
+         return self.fn(self.normx(x), self.normy(y), **kwargs)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(dim, hidden_dim),
+             GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == dim)
+
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+
+         self.attend = nn.Softmax(dim=-1)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, dim),
+             nn.Dropout(dropout)
+         ) if project_out else nn.Identity()
+
+     def forward(self, x, y):
+         # q and k come from the degraded feature x; v comes from the reference feature y
+         q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h=self.heads)
+         k = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h=self.heads)
+         v = rearrange(self.to_v(y), 'b n (h d) -> b h n d', h=self.heads)
+
+         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+         attn = self.attend(dots)
+
+         out = torch.matmul(attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class Transformer(nn.Module):
+     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(nn.ModuleList([
+                 DualPreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
+                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
+             ]))
+
+     def forward(self, x, y):  # x is the cropped input, y is the foreign reference
+         bs, c, h, w = x.size()
+
+         # image to token embedding
+         x = x.view(bs, c, -1).permute(0, 2, 1)
+         y = y.view(bs, c, -1).permute(0, 2, 1)
+
+         for attn, ff in self.layers:
+             x = attn(x, y) + x
+             x = ff(x) + x
+
+         x = x.view(bs, h, w, c).permute(0, 3, 1, 2)
+         return x
+
+
+ class RETURNX(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, y):  # identity: ignore the reference y and pass x through
+         return x
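A small smoke test for the cross-attention Transformer above (requires einops; the dimensions mirror how Visual_Encoder instantiates it at its deepest level, dim = 2**3 * 64 = 512 with depth 2, 4 heads of size 64, and a 256-wide MLP):

import torch
from models.transformer import Transformer

t = Transformer(dim=512, depth=2, heads=4, dim_head=64, mlp_dim=256)
x = torch.randn(1, 512, 12, 12)   # degraded/masked features (queries and keys)
y = torch.randn(1, 512, 12, 12)   # reference features (values)
out = t(x, y)
print(out.shape)                  # torch.Size([1, 512, 12, 12])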