Spanicin committed (verified)
Commit 784616c · Parent: d3c3894

Update videoretalking/models/LNet.py

Files changed (1): videoretalking/models/LNet.py (+138 -138)
videoretalking/models/LNet.py CHANGED
@@ -1,139 +1,139 @@
 import functools
 import torch
 import torch.nn as nn
 
-from models.transformer import RETURNX, Transformer
-from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
+from videoretalking.models.transformer import RETURNX, Transformer
+from videoretalking.models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \
     FFCADAINResBlocks, Jump, FinalBlock2d
 
 
 class Visual_Encoder(nn.Module):
     def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
         super(Visual_Encoder, self).__init__()
         self.layers = layers
         self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
         self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect)
         for i in range(layers):
             in_channels = min(ngf*(2**i), img_f)
             out_channels = min(ngf*(2**(i+1)), img_f)
             model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
             model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
             if i < 2:
                 ca_layer = RETURNX()
             else:
                 ca_layer = Transformer(2**(i+1) * ngf,2,4,ngf,ngf*4)
             setattr(self, 'ca' + str(i), ca_layer)
             setattr(self, 'ref_down' + str(i), model_ref)
             setattr(self, 'inp_down' + str(i), model_inp)
         self.output_nc = out_channels * 2
 
     def forward(self, maskGT, ref):
         x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref)
         out=[x_maskGT]
         for i in range(self.layers):
             model_ref = getattr(self, 'ref_down'+str(i))
             model_inp = getattr(self, 'inp_down'+str(i))
             ca_layer = getattr(self, 'ca'+str(i))
             x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref)
             x_maskGT = ca_layer(x_maskGT, x_ref)
             if i < self.layers - 1:
                 out.append(x_maskGT)
             else:
                 out.append(torch.cat([x_maskGT, x_ref], dim=1)) # concat ref features !
         return out
 
 
 class Decoder(nn.Module):
     def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False):
         super(Decoder, self).__init__()
         self.layers = layers
         for i in range(layers)[::-1]:
             if i == layers-1:
                 in_channels = ngf*(2**(i+1)) * 2
             else:
                 in_channels = min(ngf*(2**(i+1)), img_f)
             out_channels = min(ngf*(2**i), img_f)
             up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect)
             res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect)
             jump = Jump(out_channels, norm_layer, nonlinearity, use_spect)
 
             setattr(self, 'up' + str(i), up)
             setattr(self, 'res' + str(i), res)
             setattr(self, 'jump' + str(i), jump)
 
         self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid')
         self.output_nc = out_channels
 
     def forward(self, x, z):
         out = x.pop()
         for i in range(self.layers)[::-1]:
             res_model = getattr(self, 'res' + str(i))
             up_model = getattr(self, 'up' + str(i))
             jump_model = getattr(self, 'jump' + str(i))
             out = res_model(out, z)
             out = up_model(out)
             out = jump_model(x.pop()) + out
         out_image = self.final(out)
         return out_image
 
 
 class LNet(nn.Module):
     def __init__(
         self,
         image_nc=3,
         descriptor_nc=512,
         layer=3,
         base_nc=64,
         max_nc=512,
         num_res_blocks=9,
         use_spect=True,
         encoder=Visual_Encoder,
         decoder=Decoder
     ):
         super(LNet, self).__init__()
 
         nonlinearity = nn.LeakyReLU(0.1)
         norm_layer = functools.partial(LayerNorm2d, affine=True)
         kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect}
         self.descriptor_nc = descriptor_nc
 
         self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs)
         self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs)
         self.audio_encoder = nn.Sequential(
             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
 
             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
 
             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
 
             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
 
             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
             Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0),
         )
 
     def forward(self, audio_sequences, face_sequences):
         B = audio_sequences.size(0)
         input_dim_size = len(face_sequences.size())
         if input_dim_size > 4:
             audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
             face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
         cropped, ref = torch.split(face_sequences, 3, dim=1)
 
         vis_feat = self.encoder(cropped, ref)
         audio_feat = self.audio_encoder(audio_sequences)
         _outputs = self.decoder(vis_feat, audio_feat)
 
         if input_dim_size > 4:
             _outputs = torch.split(_outputs, B, dim=0)
             outputs = torch.stack(_outputs, dim=2)
         else:
             outputs = _outputs
         return outputs
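
The only textual change in this commit is the import rewrite: the directory-relative paths models.transformer and models.base_blocks become package-absolute videoretalking.models.* paths, so LNet.py resolves its dependencies when the repository root, rather than the videoretalking/ directory itself, is on the import path. A minimal usage sketch under that assumption follows; the sys.path entry and the tensor shapes are illustrative assumptions, not values fixed by this commit.

# Minimal usage sketch. Assumes the repository root (the directory that
# contains videoretalking/) is importable; "/path/to/repo" is hypothetical.
import sys
sys.path.insert(0, "/path/to/repo")

import torch
from videoretalking.models.LNet import LNet

# Before this commit, `from models.transformer import ...` inside LNet.py
# resolved only when videoretalking/ itself was on sys.path; with the
# package-absolute imports the model can be built from the repo root.
model = LNet().eval()  # defaults: image_nc=3, descriptor_nc=512, layer=3

# Shapes below are assumptions for illustration: a (B, 1, 80, 16) mel window
# collapses to a (B, 512, 1, 1) descriptor in audio_encoder, and
# face_sequences stacks the masked crop and a reference frame channel-wise
# (3 + 3 = 6), which forward() separates via torch.split(face_sequences, 3, dim=1).
audio = torch.randn(2, 1, 80, 16)
faces = torch.randn(2, 6, 96, 96)
with torch.no_grad():
    out = model(audio, faces)  # sigmoid-bounded image from FinalBlock2d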