Spaces: Running

gavinyuan committed · commit a104d3f · 1 parent: 523fb10
udpate: app.py import FSGenerator

Files changed:
- app.py +1 -1
- modules/layers/discriminator.py +153 -0
- modules/layers/faceshifter/hear_layers.py +60 -0
- modules/layers/faceshifter/layers.py +388 -0
- modules/layers/simswap/base_model.py +140 -0
- modules/layers/simswap/fs_networks_fix.py +223 -0
- modules/layers/simswap/pg_modules/blocks.py +325 -0
- modules/layers/simswap/pg_modules/diffaug.py +76 -0
- modules/layers/simswap/pg_modules/projected_discriminator.py +191 -0
- modules/layers/simswap/pg_modules/projector.py +160 -0
- modules/layers/smoothswap/id_embedder.py +50 -0
- modules/layers/smoothswap/resnet.py +359 -0
- modules/networks/faceshifter.py +162 -0
- modules/networks/simswap.py +230 -0
- third_party/arcface/__init__.py +2 -0
- third_party/arcface/dataloaderx.py +67 -0
- third_party/arcface/iresnet.py +311 -0
- third_party/arcface/load_dataset.py +202 -0
- third_party/arcface/margin_loss.py +463 -0
- third_party/arcface/mouth_net.py +117 -0
- third_party/arcface/mouth_net_eval.py +69 -0
- third_party/arcface/mouth_net_pl.py +358 -0
- third_party/arcface/resnet.py +2 -0
- third_party/arcface/utils_callbacks.py +141 -0
- third_party/arcface/verification.py +440 -0
app.py
CHANGED
@@ -14,7 +14,7 @@ import numpy as np
 from PIL import Image
 import tqdm
 
-
+from modules.networks.faceshifter import FSGenerator
 # from inference.alignment import norm_crop, norm_crop_with_M, paste_back
 # from inference.utils import save, get_5_from_98, get_detector, get_lmk
 # from inference.PIPNet.lib.tools import get_lmk_model, demo_image
modules/layers/discriminator.py
ADDED
@@ -0,0 +1,153 @@
import numpy as np
import torch
import torch.nn as nn
import torchvision


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class MultiscaleDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False,
        finetune=False,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                for j in range(n_layers + 2):
                    setattr(
                        self,
                        "scale" + str(i) + "_layer" + str(j),
                        getattr(netD, "model" + str(j)),
                    )
            else:
                setattr(self, "layer" + str(i), netD.model)

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )
        weights_init(self)

        if finetune:
            self.requires_grad_(False)
            for name, param in self.named_parameters():
                if 'layer0' in name:
                    param.requires_grad = True

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [
                    getattr(self, "scale" + str(num_D - 1 - i) + "_layer" + str(j))
                    for j in range(self.n_layers + 2)
                ]
            else:
                model = getattr(self, "layer" + str(num_D - 1 - i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False,
    ):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [
            [
                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [
                [
                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                    norm_layer(nf),
                    nn.LeakyReLU(0.2, True),
                ]
            ]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [
            [
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True),
            ]
        ]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, "model" + str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers + 2):
                model = getattr(self, "model" + str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)
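For orientation, a minimal smoke test of the multi-scale discriminator above (illustrative only, not part of the commit):

    import torch

    D = MultiscaleDiscriminator(input_nc=3, num_D=3, getIntermFeat=False)
    x = torch.randn(2, 3, 256, 256)   # batch of RGB images
    outs = D(x)                       # list with one entry per scale
    for o in outs:
        print(o[0].shape)             # one patch-logit map per scale, coarser each time

With getIntermFeat=True each entry instead holds the per-layer feature maps, which is what feature-matching losses typically consume.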
modules/layers/faceshifter/hear_layers.py
ADDED
@@ -0,0 +1,60 @@
import torch
from torch import nn


def conv4x4(in_c, out_c):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


def deconv4x4(in_c, out_c):
    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


class Hear_Net(nn.Module):
    def __init__(self):
        super(Hear_Net, self).__init__()
        self.down1 = conv4x4(6, 64)
        self.down2 = conv4x4(64, 128)
        self.down3 = conv4x4(128, 256)
        self.down4 = conv4x4(256, 512)
        self.down5 = conv4x4(512, 512)

        self.up1 = deconv4x4(512, 512)
        self.up2 = deconv4x4(512 * 2, 256)
        self.up3 = deconv4x4(256 * 2, 128)
        self.up4 = deconv4x4(128 * 2, 64)
        self.up5 = nn.Conv2d(64 * 2, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):  # input: (B, 6, 256, 256)
        c1 = self.down1(x)
        c2 = self.down2(c1)
        c3 = self.down3(c2)
        c4 = self.down4(c3)
        c5 = self.down5(c4)

        m1 = self.up1(c5)
        m1 = torch.cat((c4, m1), dim=1)
        m2 = self.up2(m1)
        m2 = torch.cat((c3, m2), dim=1)
        m3 = self.up3(m2)
        m3 = torch.cat((c2, m3), dim=1)
        m4 = self.up4(m3)
        m4 = torch.cat((c1, m4), dim=1)

        out = nn.functional.interpolate(m4, scale_factor=2, mode='bilinear', align_corners=True)
        out = self.up5(out)
        return torch.tanh(out)  # output: (B, 3, 256, 256)


if __name__ == '__main__':
    y_cat = torch.randn(5, 6, 256, 256)
    hear = Hear_Net()
    y_st = hear(y_cat)
    print(y_st.shape)
modules/layers/faceshifter/layers.py
ADDED
@@ -0,0 +1,388 @@
"""
This file is only for testing mask regularization.
If it works, it will be merged with `layers.py`.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class AADLayer(nn.Module):
    def __init__(self, c_x, attr_c, c_id=256):
        super(AADLayer, self).__init__()
        self.attr_c = attr_c
        self.c_id = c_id
        self.c_x = c_x

        self.conv1 = nn.Conv2d(
            attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.conv2 = nn.Conv2d(
            attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.fc1 = nn.Linear(c_id, c_x)
        self.fc2 = nn.Linear(c_id, c_x)
        self.norm = nn.InstanceNorm2d(c_x, affine=False)

        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, h_in, z_attr, z_id):
        # h_in: (B, c_x, n, n); z_id: (B, c_id); z_attr: (B, attr_c, n, n)
        h = self.norm(h_in)
        gamma_attr = self.conv1(z_attr)
        beta_attr = self.conv2(z_attr)

        gamma_id = self.fc1(z_id)
        beta_id = self.fc2(z_id)
        A = gamma_attr * h + beta_attr
        gamma_id = gamma_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        beta_id = beta_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
        I = gamma_id * h + beta_id

        M = torch.sigmoid(self.conv_h(h))

        out = (torch.ones_like(M).to(M.device) - M) * A + M * I
        return out, torch.mean(torch.ones_like(M).to(M.device) - M, dim=[1, 2, 3])


class AAD_ResBlk(nn.Module):
    def __init__(self, cin, cout, c_attr, c_id=256):
        super(AAD_ResBlk, self).__init__()
        self.cin = cin
        self.cout = cout

        self.AAD1 = AADLayer(cin, c_attr, c_id)
        self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)

        self.AAD2 = AADLayer(cin, c_attr, c_id)
        self.conv2 = nn.Conv2d(
            cin, cout, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.relu2 = nn.ReLU(inplace=True)

        if cin != cout:
            self.AAD3 = AADLayer(cin, c_attr, c_id)
            self.conv3 = nn.Conv2d(
                cin, cout, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.relu3 = nn.ReLU(inplace=True)

    def forward(self, h, z_attr, z_id):
        x, m1_ = self.AAD1(h, z_attr, z_id)
        x = self.relu1(x)
        x = self.conv1(x)

        x, m2_ = self.AAD2(x, z_attr, z_id)
        x = self.relu2(x)
        x = self.conv2(x)

        m = m1_ + m2_

        if self.cin != self.cout:
            h, m3_ = self.AAD3(h, z_attr, z_id)
            h = self.relu3(h)
            h = self.conv3(h)
            m += m3_
        x = x + h

        return x, m


def weight_init(m):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.001)
        m.bias.data.zero_()
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)

    if isinstance(m, nn.ConvTranspose2d):
        nn.init.xavier_normal_(m.weight.data)


def conv4x4(in_c, out_c, norm=nn.BatchNorm2d):
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        ),
        norm(out_c),
        nn.LeakyReLU(0.1, inplace=True),
    )


class deconv4x4(nn.Module):
    def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
        super(deconv4x4, self).__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        self.bn = norm(out_c)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, input, skip):
        x = self.deconv(input)
        x = self.bn(x)
        x = self.lrelu(x)
        return torch.cat((x, skip), dim=1)


class MLAttrEncoder(nn.Module):
    def __init__(self, finetune=False, downup=False):
        super(MLAttrEncoder, self).__init__()

        self.downup = downup
        if self.downup:
            self.conv00 = conv4x4(3, 16)
            self.conv01 = conv4x4(16, 32)
            self.deconv7 = deconv4x4(64, 16)

        self.conv1 = conv4x4(3, 32)
        self.conv2 = conv4x4(32, 64)
        self.conv3 = conv4x4(64, 128)
        self.conv4 = conv4x4(128, 256)
        self.conv5 = conv4x4(256, 512)
        self.conv6 = conv4x4(512, 1024)
        self.conv7 = conv4x4(1024, 1024)

        self.deconv1 = deconv4x4(1024, 1024)
        self.deconv2 = deconv4x4(2048, 512)
        self.deconv3 = deconv4x4(1024, 256)
        self.deconv4 = deconv4x4(512, 128)
        self.deconv5 = deconv4x4(256, 64)
        self.deconv6 = deconv4x4(128, 32)

        self.apply(weight_init)

        self.finetune = finetune
        if finetune:
            for name, param in self.named_parameters():
                param.requires_grad = False
            if self.downup:
                self.conv00.requires_grad_(True)
                self.conv01.requires_grad_(True)
                self.deconv7.requires_grad_(True)

    def forward(self, Xt):
        if self.downup:
            feat0 = self.conv00(Xt)     # (16, 256, 256)
            feat1 = self.conv01(feat0)  # (32, 128, 128)
        else:
            feat0 = None
            feat1 = self.conv1(Xt)
            # 32x128x128

        feat2 = self.conv2(feat1)
        # 64x64x64
        feat3 = self.conv3(feat2)
        # 128x32x32
        feat4 = self.conv4(feat3)
        # 256x16x16
        feat5 = self.conv5(feat4)
        # 512x8x8
        feat6 = self.conv6(feat5)
        # 1024x4x4

        if self.downup:
            z_attr1 = self.conv7(feat6)
            # 1024x2x2
            z_attr2 = self.deconv1(z_attr1, feat6)
            z_attr3 = self.deconv2(z_attr2, feat5)
            z_attr4 = self.deconv3(z_attr3, feat4)
            z_attr5 = self.deconv4(z_attr4, feat3)
            z_attr6 = self.deconv5(z_attr5, feat2)
            z_attr7 = self.deconv6(z_attr6, feat1)  # (128,64,64)+(32,128,128)->(64,128,128)
            z_attr8 = self.deconv7(z_attr7, feat0)  # (64,128,128)+(16,256,256)->(32,256,256)
            z_attr9 = F.interpolate(
                z_attr8, scale_factor=2, mode="bilinear", align_corners=True
            )  # (32,512,512)
            return (
                z_attr1,
                z_attr2,
                z_attr3,
                z_attr4,
                z_attr5,
                z_attr6,
                z_attr7,
                z_attr8,
                z_attr9,
            )
        else:
            z_attr1 = self.conv7(feat6)
            # 1024x2x2
            z_attr2 = self.deconv1(z_attr1, feat6)
            z_attr3 = self.deconv2(z_attr2, feat5)
            z_attr4 = self.deconv3(z_attr3, feat4)
            z_attr5 = self.deconv4(z_attr4, feat3)
            z_attr6 = self.deconv5(z_attr5, feat2)
            z_attr7 = self.deconv6(z_attr6, feat1)
            z_attr8 = F.interpolate(
                z_attr7, scale_factor=2, mode="bilinear", align_corners=True
            )
            return (
                z_attr1,
                z_attr2,
                z_attr3,
                z_attr4,
                z_attr5,
                z_attr6,
                z_attr7,
                z_attr8,
            )


class AADGenerator(nn.Module):
    def __init__(self, c_id=256, finetune=False, downup=False):
        super(AADGenerator, self).__init__()
        self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id)
        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id)
        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id)
        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id)
        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id)
        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id)
        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id)
        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id)

        self.downup = downup
        if downup:
            self.AADBlk8_0 = AAD_ResBlk(64, 32, 32, c_id)
            self.AADBlk8_1 = AAD_ResBlk(32, 3, 32, c_id)

        self.apply(weight_init)

        if finetune:
            for name, param in self.named_parameters():
                param.requires_grad = False
            self.AADBlk8_0.requires_grad_(True)
            self.AADBlk8_1.requires_grad_(True)

    def forward(self, z_attr, z_id):
        m = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1))
        scale = z_attr[0].shape[2] // 2  # adaptive support for 512x512, 1024x1024
        m = F.interpolate(m, scale_factor=scale, mode='bilinear', align_corners=True)
        m2, m2_ = self.AADBlk1(m, z_attr[0], z_id)
        m2 = F.interpolate(
            m2,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m3, m3_ = self.AADBlk2(m2, z_attr[1], z_id)
        m3 = F.interpolate(
            m3,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m4, m4_ = self.AADBlk3(m3, z_attr[2], z_id)
        m4 = F.interpolate(
            m4,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m5, m5_ = self.AADBlk4(m4, z_attr[3], z_id)
        m5 = F.interpolate(
            m5,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m6, m6_ = self.AADBlk5(m5, z_attr[4], z_id)
        m6 = F.interpolate(
            m6,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m7, m7_ = self.AADBlk6(m6, z_attr[5], z_id)
        m7 = F.interpolate(
            m7,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )
        m8, m8_ = self.AADBlk7(m7, z_attr[6], z_id)
        m8 = F.interpolate(
            m8,
            scale_factor=2,
            mode="bilinear",
            align_corners=True,
        )

        if self.downup:
            y0, m9_ = self.AADBlk8_0(m8, z_attr[7], z_id)
            y0 = F.interpolate(y0, scale_factor=2, mode='bilinear', align_corners=True)
            y1, m10_ = self.AADBlk8_1(y0, z_attr[8], z_id)
            y = torch.tanh(y1)
        else:
            y, m9_ = self.AADBlk8(m8, z_attr[7], z_id)
            y = torch.tanh(y)
        return y  # , m  # yuange


class AEI_Net(nn.Module):
    def __init__(self, c_id=512, finetune=False, downup=False):
        super(AEI_Net, self).__init__()
        self.encoder = MLAttrEncoder(finetune=finetune, downup=downup)
        self.generator = AADGenerator(c_id, finetune=finetune, downup=downup)

    def forward(self, Xt, z_id):
        attr = self.encoder(Xt)
        Y = self.generator(attr, z_id)  # yuange
        return Y, attr

    def get_attr(self, X):
        return self.encoder(X)

    def trainable_params(self):
        train_params = []
        for param in self.parameters():
            if param.requires_grad:
                train_params.append(param)
        return train_params


if __name__ == "__main__":
    aie = AEI_Net(512).eval()
    x = aie(torch.randn(1, 3, 512, 512), torch.randn(1, 512))

    # def numel(m: torch.nn.Module, only_trainable: bool = False):
    #     """
    #     returns the total number of parameters used by `m` (only counting
    #     shared parameters once); if `only_trainable` is True, then only
    #     includes parameters with `requires_grad = True`
    #     """
    #     parameters = list(m.parameters())
    #     if only_trainable:
    #         parameters = [p for p in parameters if p.requires_grad]
    #     unique = {p.data_ptr(): p for p in parameters}.values()
    #     return sum(p.numel() for p in unique)
    #
    #
    # print(numel(aie, True))
    # print(x[0].size())
    # print(len(x[-1]))

    import thop

    img = torch.randn(1, 3, 256, 256)
    latent = torch.randn(1, 512)
    net = aie
    flops, params = thop.profile(net, inputs=(img, latent), verbose=False)
    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
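As a standalone shape check of the AAD blending above, where the learned mask M mixes the attribute branch A and the identity branch I as out = (1 - M) * A + M * I (illustrative only, not part of the commit):

    import torch

    aad = AADLayer(c_x=64, attr_c=64, c_id=256)
    h = torch.randn(2, 64, 8, 8)        # activation being modulated
    z_attr = torch.randn(2, 64, 8, 8)   # attribute feature at the same resolution
    z_id = torch.randn(2, 256)          # identity embedding
    out, one_minus_m = aad(h, z_attr, z_id)
    print(out.shape, one_minus_m.shape)  # torch.Size([2, 64, 8, 8]) torch.Size([2])

The second return value is the per-sample mean of 1 - M, which is what the mask regularization mentioned in the module docstring can penalize.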
modules/layers/simswap/base_model.py
ADDED
@@ -0,0 +1,140 @@
import os
import torch
import sys


class BaseModel(torch.nn.Module):
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used at test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids=None):
        save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if torch.cuda.is_available():
            network.cuda()

    def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
        save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s does not exist yet!' % save_path)
            if network_label == 'G':
                raise RuntimeError('Generator must exist!')
        else:
            # network.load_state_dict(torch.load(save_path))
            try:
                network.load_state_dict(torch.load(save_path))
            except:
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print(
                            'Pretrained network %s has excessive layers; only loading layers that are used' % network_label)
                except:
                    print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3, 0):
                        not_initialized = set()
                    else:
                        from sets import Set
                        not_initialized = Set()

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])

                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    # helper loading function that can be used by subclasses
    def load_optim(self, network, network_label, epoch_label, save_dir=''):
        save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s does not exist yet!' % save_path)
            if network_label == 'G':
                raise RuntimeError('Generator must exist!')
        else:
            # network.load_state_dict(torch.load(save_path))
            try:
                network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
            except:
                pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print(
                            'Pretrained network %s has excessive layers; only loading layers that are used' % network_label)
                except:
                    print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3, 0):
                        not_initialized = set()
                    else:
                        from sets import Set
                        not_initialized = Set()

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])

                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate(self):
        pass
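A minimal sketch of how a subclass might use the save/load helpers above; MyModel and netG are hypothetical names for illustration, not from this commit:

    import torch

    class MyModel(BaseModel):
        def initialize(self, opt):
            BaseModel.initialize(self, opt)
            self.netG = torch.nn.Linear(4, 4)  # stand-in for a real generator

        def save(self, label):
            # writes <save_dir>/<label>_net_G.pth
            self.save_network(self.netG, 'G', label)

    # Restoring later mirrors the same naming scheme:
    # model.load_network(model.netG, 'G', 'latest')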
modules/layers/simswap/fs_networks_fix.py
ADDED
@@ -0,0 +1,223 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia


class InstanceNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        x = x - torch.mean(x, (2, 3), True)
        tmp = torch.mul(x, x)  # or x ** 2
        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
        return x * tmp


class ApplyStyle(nn.Module):
    """
    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """
    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        style = self.linear(latent)  # style => [batch_size, n_channels*2]
        shape = [-1, 2, x.size(1), 1, 1]
        style = style.view(shape)    # [batch_size, 2, n_channels, ...]
        # x = x * (style[:, 0] + 1.) + style[:, 1]
        x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
        return x


class ResnetBlock_Adain(nn.Module):
    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()

        p = 0
        conv1 = []
        if padding_type == 'reflect':
            conv1 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv1 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv1 = nn.Sequential(*conv1)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        p = 0
        conv2 = []
        if padding_type == 'reflect':
            conv2 += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv2 += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
        self.conv2 = nn.Sequential(*conv2)
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        y = self.conv1(x)
        y = self.style1(y, dlatents_in_slice)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.style2(y, dlatents_in_slice)
        out = x + y
        return out


class Generator_Adain_Upsample(nn.Module):
    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect',
                 mouth_net_param: dict = None,
                 ):
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()

        self.latent_size = latent_size

        self.mouth_net_param = mouth_net_param
        if mouth_net_param.get('use'):
            self.latent_size += mouth_net_param.get('feature_dim')

        activation = nn.ReLU(True)

        self.deep = deep

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        ### downsample
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                   norm_layer(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                   norm_layer(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                   norm_layer(512), activation)

        if self.deep:
            self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
                                       norm_layer(512), activation)

        ### resnet blocks
        BN = []
        for i in range(n_blocks):
            BN += [
                ResnetBlock_Adain(512, latent_size=self.latent_size,
                                  padding_type=padding_type, activation=activation)]
        self.BottleNeck = nn.Sequential(*BN)

        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512), activation
            )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64), activation
        )
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

        self.register_buffer(
            name="trans_matrix",
            tensor=torch.tensor(
                [
                    [
                        [1.07695457, -0.03625215, -1.56352194],
                        [0.03625215, 1.07695457, -5.32134629],
                    ]
                ],
                requires_grad=False,
            ).float(),
        )

    def forward(self, source, target, net_arc, mouth_net=None):
        x = target  # 3*224*224
        if net_arc is None:
            id_vector = source
        else:
            with torch.no_grad():
                ''' 1. get id '''
                # M = self.trans_matrix.repeat(source.size()[0], 1, 1)
                # source = kornia.geometry.transform.warp_affine(source, M, (256, 256))
                resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
                id_vector = F.normalize(net_arc(resize_input), dim=-1, p=2)

                ''' 2. get mouth feature '''
                if mouth_net is not None:
                    w1, h1, w2, h2 = self.mouth_net_param.get('crop_param')
                    mouth_input = resize_input[:, :, h1:h2, w1:w2]
                    mouth_feat = mouth_net(mouth_input)
                    id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B, dim_id + dim_mouth)

        skip1 = self.first_layer(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        if self.deep:
            skip4 = self.down3(skip3)
            x = self.down4(skip4)
        else:
            x = self.down3(skip3)
        bot = []
        bot.append(x)
        features = []
        for i in range(len(self.BottleNeck)):
            x = self.BottleNeck[i](x, id_vector)
            bot.append(x)

        if self.deep:
            x = self.up4(x)
            features.append(x)
        x = self.up3(x)
        features.append(x)
        x = self.up2(x)
        features.append(x)
        x = self.up1(x)
        features.append(x)
        x = self.last_layer(x)
        # x = (x + 1) / 2

        # return x, bot, features, dlatents
        return x


if __name__ == "__main__":
    import thop

    img = torch.randn(1, 3, 256, 256)
    latent = torch.randn(1, 512)
    net = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9,
                                   mouth_net_param={"use": False})
    flops, params = thop.profile(net, inputs=(latent, img, None, None), verbose=False)
    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
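An illustrative forward pass (not part of the commit): with net_arc=None the generator skips the ArcFace branch and treats source as an already-computed identity embedding, so only the shapes below matter:

    import torch

    G = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512,
                                 n_blocks=9, mouth_net_param={"use": False})
    id_vec = torch.randn(1, 512)           # precomputed identity vector
    target = torch.randn(1, 3, 256, 256)   # attribute/target image
    out = G(id_vec, target, net_arc=None)  # AdaIN bottleneck conditioned on id_vec
    print(out.shape)                       # torch.Size([1, 3, 256, 256])

Note the output is not passed through tanh; the (x + 1) / 2 rescaling is left commented out, so callers are expected to handle range normalization themselves.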
modules/layers/simswap/pg_modules/blocks.py
ADDED
@@ -0,0 +1,325 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


### single layers


def conv2d(*args, **kwargs):
    return spectral_norm(nn.Conv2d(*args, **kwargs))


def convTranspose2d(*args, **kwargs):
    return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))


def embedding(*args, **kwargs):
    return spectral_norm(nn.Embedding(*args, **kwargs))


def linear(*args, **kwargs):
    return spectral_norm(nn.Linear(*args, **kwargs))


def NormLayer(c, mode='batch'):
    if mode == 'group':
        return nn.GroupNorm(c//2, c)
    elif mode == 'batch':
        return nn.BatchNorm2d(c)


### Activations


class GLU(nn.Module):
    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, "channels don't divide 2!"
        nc = int(nc/2)
        return x[:, :nc] * torch.sigmoid(x[:, nc:])


class Swish(nn.Module):
    def forward(self, feat):
        return feat * torch.sigmoid(feat)


### Upblocks


class InitLayer(nn.Module):
    def __init__(self, nz, channel, sz=4):
        super().__init__()

        self.init = nn.Sequential(
            convTranspose2d(nz, channel*2, sz, 1, 0, bias=False),
            NormLayer(channel*2),
            GLU(),
        )

    def forward(self, noise):
        noise = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(noise)


def UpBlockSmall(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NormLayer(out_planes*2), GLU())
    return block


class UpBlockSmallCond(nn.Module):
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)

        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn = which_bn(2*out_planes)
        self.act = GLU()

    def forward(self, x, c):
        x = self.up(x)
        x = self.conv(x)
        x = self.bn(x, c)
        x = self.act(x)
        return x


def UpBlockBig(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2), GLU(),
        conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        NormLayer(out_planes*2), GLU()
    )
    return block


class UpBlockBigCond(nn.Module):
    def __init__(self, in_planes, out_planes, z_dim):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False)
        self.conv2 = conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False)

        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
        self.bn1 = which_bn(2*out_planes)
        self.bn2 = which_bn(2*out_planes)
        self.act = GLU()
        self.noise = NoiseInjection()

    def forward(self, x, c):
        # block 1
        x = self.up(x)
        x = self.conv1(x)
        x = self.noise(x)
        x = self.bn1(x, c)
        x = self.act(x)

        # block 2
        x = self.conv2(x)
        x = self.noise(x)
        x = self.bn2(x, c)
        x = self.act(x)

        return x


class SEBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
            Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, feat_small, feat_big):
        return feat_big * self.main(feat_small)


### Downblocks


class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                groups=in_channels, bias=bias, padding=1)
        self.pointwise = conv2d(in_channels, out_channels,
                                kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class DownBlock(nn.Module):
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        if not separable:
            self.main = nn.Sequential(
                conv2d(in_planes, out_planes, 4, 2, 1),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            self.main = nn.Sequential(
                SeparableConv2d(in_planes, out_planes, 3),
                NormLayer(out_planes),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2, 2),
            )

    def forward(self, feat):
        return self.main(feat)


class DownBlockPatch(nn.Module):
    def __init__(self, in_planes, out_planes, separable=False):
        super().__init__()
        self.main = nn.Sequential(
            DownBlock(in_planes, out_planes, separable),
            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
            NormLayer(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, feat):
        return self.main(feat)


### CSM


class ResidualConvUnit(nn.Module):
    def __init__(self, cin, activation, bn):
        super().__init__()
        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        return self.skip_add.add(self.conv(x), x)


class FeatureFusionBlock(nn.Module):
    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features//2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, xs[1])

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output


### Misc


class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        if noise is None:
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width).to(feat.device)

        return feat + self.weight * noise


class CCBN(nn.Module):
    ''' conditional batchnorm '''
    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
        super().__init__()
        self.output_size, self.input_size = output_size, input_size

        # Prepare gain and bias layers
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)

        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum

        self.register_buffer('stored_mean', torch.zeros(output_size))
        self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                           self.training, 0.1, self.eps)
        return out * gain + bias


class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, size, mode='bilinear', align_corners=False):
        """Init.
        Args:
            size (int or tuple): output size
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x
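Two quick shape checks for the building blocks above (illustrative only, not part of the commit): a DownBlock halves the spatial size, and GLU halves the channel count by gating one half of the channels with a sigmoid of the other half:

    import torch

    blk = DownBlock(in_planes=64, out_planes=128)
    x = torch.randn(1, 64, 32, 32)
    print(blk(x).shape)   # torch.Size([1, 128, 16, 16])

    glu = GLU()
    y = torch.randn(1, 8, 4, 4)
    print(glu(y).shape)   # torch.Size([1, 4, 4, 4])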
modules/layers/simswap/pg_modules/diffaug.py
ADDED
@@ -0,0 +1,76 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)
        for p in policy.split(','):
            for f in AUGMENT_FNS[p]:
                x = f(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
    return x


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.2):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
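Typical usage (illustrative only, not part of the commit): DiffAugment applies every function registered for each comma-separated policy in turn, and all three policies preserve the input shape:

    import torch

    imgs = torch.randn(8, 3, 256, 256)  # NCHW, i.e. channels_first=True
    aug = DiffAugment(imgs, policy='color,translation,cutout')
    print(aug.shape)                    # torch.Size([8, 3, 256, 256])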
modules/layers/simswap/pg_modules/projected_discriminator.py
ADDED
@@ -0,0 +1,191 @@
from functools import partial
import numpy as np
import torch
import torch.nn as nn

from modules.layers.simswap.pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from modules.layers.simswap.pg_modules.projector import F_RandomProj
from modules.layers.simswap.pg_modules.diffaug import DiffAugment


class SingleDisc(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
        super().__init__()
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start_sz values that are not powers of two
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # down blocks
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        return self.main(x)


class SingleDiscCond(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False,
                 c_dim=1000, cmap_dim=64, embedding_dim=128):
        super().__init__()
        self.cmap_dim = cmap_dim

        # midas channels
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # interpolate for start_sz values that are not powers of two
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if given ndf, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # down blocks
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2
        self.main = nn.Sequential(*layers)

        # additions for conditioning on class information
        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
        self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
        self.embed_proj = nn.Sequential(
            nn.Linear(self.embed.embedding_dim, self.cmap_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, c):
        h = self.main(x)
        out = self.cls(h)

        # conditioning via projection
        cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
        out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out


class MultiScaleD(nn.Module):
    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        separable=False,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]
        Disc = SingleDiscCond if cond else SingleDisc

        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            # (key, module) pairs for nn.ModuleDict
            mini_discs.append((str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)))
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c):
        all_logits = []
        for k, disc in self.mini_discs.items():
            res = disc(features[k], c).view(features[k].size(0), -1)
            all_logits.append(res)

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits


class ProjectedDiscriminator(torch.nn.Module):
    def __init__(
        self,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.diffaug = diffaug
        self.interp224 = interp224
        self.feature_network = F_RandomProj(**backbone_kwargs)
        self.discriminator = MultiScaleD(
            channels=self.feature_network.CHANNELS,
            resolutions=self.feature_network.RESOLUTIONS,
            **backbone_kwargs,
        )

    def train(self, mode=True):
        # the pretrained feature network always stays in eval mode
        self.feature_network = self.feature_network.train(False)
        self.discriminator = self.discriminator.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def get_feature(self, x):
        features = self.feature_network(x, get_features=True)
        return features

    def forward(self, x, c):
        # if self.diffaug:
        #     x = DiffAugment(x, policy='color,translation,cutout')

        # if self.interp224:
        #     x = F.interpolate(x, 224, mode='bilinear', align_corners=False)

        features, backbone_features = self.feature_network(x)
        logits = self.discriminator(features, c)

        return logits, backbone_features
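
A minimal usage sketch of the discriminator above (batch size and shapes are assumptions; constructing it requires the EfficientNet-lite0 checkpoint referenced in projector.py):

import torch

D = ProjectedDiscriminator(diffaug=False, interp224=False)
imgs = torch.randn(2, 3, 256, 256)        # (B, C, H, W), matching the 256-res backbone
logits, backbone_feats = D(imgs, c=None)  # per-scale logits, concatenated along dim 1
print(logits.shape)                       # (2, total number of spatial logits)
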
modules/layers/simswap/pg_modules/projector.py
ADDED
@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import timm

from modules.layers.simswap.pg_modules.blocks import FeatureFusionBlock


def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
    # shapes
    out_channels = [cout, cout * 2, cout * 4, cout * 8] if expand else [cout] * 4

    scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)

    scratch.CHANNELS = out_channels

    return scratch


def _make_scratch_csm(scratch, in_channels, cout, expand):
    scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
    scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
    scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    # last refinenet does not expand to save channels in higher dimensions
    scratch.CHANNELS = [cout, cout, cout * 2, cout * 4] if expand else [cout] * 4

    return scratch


def _make_efficientnet(model):
    pretrained = nn.Module()
    pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
    pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
    pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
    pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
    return pretrained


def calc_channels(pretrained, inp_res=224):
    channels = []
    tmp = torch.zeros(1, 3, inp_res, inp_res)

    # forward pass
    tmp = pretrained.layer0(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer1(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer2(tmp)
    channels.append(tmp.shape[1])
    tmp = pretrained.layer3(tmp)
    channels.append(tmp.shape[1])

    return channels


def _make_projector(im_res, cout, proj_type, expand=False):
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    model = timm.create_model('tf_efficientnet_lite0', pretrained=False,
                              checkpoint_path='/gavin/code/FaceSwapping/modules/third_party/efficientnet/'
                                              'tf_efficientnet_lite0-0aa007d2.pth')
    pretrained = _make_efficientnet(model)

    # determine the resolution of the feature maps; this is later used to calculate the number
    # of down blocks in the discriminators. Interestingly, the best results are achieved
    # by fixing this to 256, i.e., we use the same number of down blocks per discriminator
    # independent of the dataset resolution
    im_res = 256
    pretrained.RESOLUTIONS = [im_res // 4, im_res // 8, im_res // 16, im_res // 32]
    pretrained.CHANNELS = calc_channels(pretrained)

    if proj_type == 0:
        return pretrained, None

    ### Build CCM
    scratch = nn.Module()
    scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1:
        return pretrained, scratch

    ### Build CSM
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2 so the feature map resolution doubles
    pretrained.RESOLUTIONS = [res * 2 for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch


class F_RandomProj(nn.Module):
    def __init__(
        self,
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        super().__init__()
        self.proj_type = proj_type
        self.cout = cout
        self.expand = expand

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout,
                                                        proj_type=self.proj_type, expand=self.expand)
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x, get_features=False):
        # predict feature maps
        out0 = self.pretrained.layer0(x)
        out1 = self.pretrained.layer1(out0)
        out2 = self.pretrained.layer2(out1)
        out3 = self.pretrained.layer3(out2)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        backbone_features = {
            '0': out0,
            '1': out1,
            '2': out2,
            '3': out3,
        }
        if get_features:
            return backbone_features

        if self.proj_type == 0:
            return backbone_features

        out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0'])
        out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1'])
        out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2'])
        out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3'])

        out = {
            '0': out0_channel_mixed,
            '1': out1_channel_mixed,
            '2': out2_channel_mixed,
            '3': out3_channel_mixed,
        }

        if self.proj_type == 1:
            return out

        # from bottom to top
        out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
        out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
        out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
        out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)

        out = {
            '0': out0_scale_mixed,
            '1': out1_scale_mixed,
            '2': out2_scale_mixed,
            '3': out3_scale_mixed,
        }

        return out, backbone_features
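
For orientation, a small sketch that prints what the projector produces (keys '0'-'3' run from the highest-resolution feature map to the lowest; exact channel counts are computed at construction time and depend on cout/expand, and construction assumes the EfficientNet checkpoint above is present):

proj = F_RandomProj(proj_type=2, cout=64, expand=True)
feats, backbone_feats = proj(torch.randn(1, 3, 256, 256))
for k in feats:
    # mixed features feed the mini-discriminators; backbone features are the raw EfficientNet maps
    print(k, tuple(feats[k].shape), tuple(backbone_feats[k].shape))
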
modules/layers/smoothswap/id_embedder.py
ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.layers.smoothswap.resnet import resnet50


class IdentityHead(nn.Module):
    def __init__(self):
        super(IdentityHead, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 4, 1024),
            nn.BatchNorm1d(num_features=1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(num_features=512)
        )

        for m in self.modules():
            # NOTE: the head only contains BatchNorm1d layers, so this BatchNorm2d
            # check never fires; BN1d already defaults to weight=1, bias=0, which
            # is what this initialization intended.
            if isinstance(m, (nn.BatchNorm2d,)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = F.normalize(x)
        return x


class IdentityEmbedder(nn.Module):
    def __init__(self):
        super(IdentityEmbedder, self).__init__()

        self.backbone = resnet50(pretrained=False)
        self.head = IdentityHead()

    def forward(self, x_src):
        x_src = self.backbone(x_src)
        x_src = self.head(x_src)
        return x_src


if __name__ == '__main__':
    img = torch.randn((11, 3, 256, 256)).cuda()
    net = IdentityEmbedder().cuda()
    out = net(img)
    print(out.shape)
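
Because IdentityHead ends with F.normalize, embeddings live on the unit sphere and identity similarity reduces to a plain dot product; a minimal sketch (random weights, CPU):

emb = IdentityEmbedder().eval()
with torch.no_grad():
    a = emb(torch.randn(2, 3, 256, 256))
    b = emb(torch.randn(2, 3, 256, 256))
cos_sim = (a * b).sum(dim=1)  # cosine similarity in [-1, 1], shape (2,)
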
modules/layers/smoothswap/resnet.py
ADDED
@@ -0,0 +1,359 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url  # used by _resnet below; stands in for torchvision's private .utils helper


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        ''' head '''
        # op1. vanilla ResNet
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        # op2. smooth-swap resnet
        # the FC head is defined in id_embedder.py

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
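
Since this variant drops the classification head (self.fc is commented out above), loading a torchvision ImageNet checkpoint with strict key matching would fail on the leftover fc.* entries. A hedged sketch of tolerant loading:

from torch.hub import load_state_dict_from_url

model = resnet50(pretrained=False)
state = load_state_dict_from_url(model_urls['resnet50'], progress=True)
state.pop('fc.weight', None)  # head weights have no destination in this variant
state.pop('fc.bias', None)
model.load_state_dict(state, strict=False)
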
modules/networks/faceshifter.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import kornia
import warnings

from modules.layers.faceshifter.layers import AEI_Net
from modules.layers.faceshifter.hear_layers import Hear_Net
from third_party.arcface import iresnet100, MouthNet

make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))


class FSGenerator(nn.Module):
    def __init__(self,
                 id_ckpt: str = None,
                 id_dim: int = 512,
                 mouth_net_param: dict = None,
                 in_size: int = 256,
                 finetune: bool = False,
                 downup: bool = False,
                 ):
        super(FSGenerator, self).__init__()

        ''' MouthNet '''
        self.use_mouth_net = mouth_net_param.get('use')
        self.mouth_feat_dim = 0
        self.mouth_net = None
        if self.use_mouth_net:
            self.mouth_feat_dim = mouth_net_param.get('feature_dim')
            self.mouth_crop_param = mouth_net_param.get('crop_param')
            mouth_weight_path = make_abs_path(mouth_net_param.get('weight_path'))
            self.mouth_net = MouthNet(
                bisenet=None,
                feature_dim=self.mouth_feat_dim,
                crop_param=self.mouth_crop_param
            )
            self.mouth_net.load_backbone(mouth_weight_path)
            print("[FaceShifter Generator] MouthNet loaded from %s" % mouth_weight_path)
            self.mouth_net.eval()
            self.mouth_net.requires_grad_(False)

        self.G = AEI_Net(c_id=id_dim + self.mouth_feat_dim, finetune=finetune, downup=downup)
        self.iresnet = iresnet100()
        if id_ckpt is not None:
            self.iresnet.load_state_dict(torch.load(id_ckpt, "cpu"))
        else:
            warnings.warn("Face ID backbone [%s] not found!" % id_ckpt)
            raise FileNotFoundError("Face ID backbone [%s] not found!" % id_ckpt)
        self.iresnet.eval()
        self.register_buffer(
            name="trans_matrix",
            tensor=torch.tensor(
                [
                    [
                        [1.07695457, -0.03625215, -1.56352194 * (in_size / 256)],
                        [0.03625215, 1.07695457, -5.32134629 * (in_size / 256)],
                    ]
                ],
                requires_grad=False,
            ).float(),
        )
        self.in_size = in_size

        self.iresnet.requires_grad_(False)

    def forward(self, source, target, infer=False):
        with torch.no_grad():
            ''' 1. get id '''
            if infer:
                resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
                id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)
            else:
                M = self.trans_matrix.repeat(source.size()[0], 1, 1)
                source = kornia.geometry.transform.warp_affine(source, M, (self.in_size, self.in_size))

                # import cv2
                # from tricks import Trick
                # cv2.imwrite('warpped_source.png', Trick.tensor_to_arr(source)[0, :, :, ::-1])

                resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
                id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)

            ''' 2. get mouth feature '''
            if self.use_mouth_net:
                w1, h1, w2, h2 = self.mouth_crop_param
                mouth_input = resize_input[:, :, h1:h2, w1:w2]  # crop the mouth region from the 112x112 input
                mouth_feat = self.mouth_net(mouth_input)
                id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B, dim_id + dim_mouth)

        x, att = self.G(target, id_vector)
        return x, id_vector, att

    def get_recon(self):
        return self.G.get_recon_tensor()

    def get_att(self, x):
        return self.G.get_attr(x)


class FSHearNet(nn.Module):
    def __init__(self, aei_path: str):
        super(FSHearNet, self).__init__()
        ''' Stage I. AEI_Net '''
        self.aei = FSGenerator(
            id_ckpt=make_abs_path("../../modules/third_party/arcface/weights/ms1mv3_arcface_r100_fp16/backbone.pth")
        ).requires_grad_(False)
        print('Loading pre-trained AEI-Net from %s...' % aei_path)
        self._load_pretrained_aei(aei_path)
        print('Loaded.')

        ''' Stage II. HEAR_Net '''
        self.hear = Hear_Net()

    def _load_pretrained_aei(self, path: str):
        if '.ckpt' in path:
            from trainer.faceshifter.extract_ckpt import extract_generator
            pth_folder = make_abs_path('../../trainer/faceshifter/extracted_ckpt')
            pth_name = 'hear_tmp.pth'
            assert '.pth' in pth_name
            state_dict = extract_generator(load_path=path, path=os.path.join(pth_folder, pth_name))
            self.aei.load_state_dict(state_dict, strict=False)
            self.aei.eval()
        elif '.pth' in path:
            self.aei.load_state_dict(torch.load(path, "cpu"), strict=False)
            self.aei.eval()
        else:
            raise FileNotFoundError('%s (.ckpt or .pth) not found.' % path)

    def forward(self, source, target):
        with torch.no_grad():
            y_hat_st, _, _ = self.aei(source, target, infer=True)
            y_hat_tt, _, _ = self.aei(target, target, infer=True)
            delta_y_t = target - y_hat_tt
            y_cat = torch.cat([y_hat_st, delta_y_t], dim=1)  # (B, 6, 256, 256)

        y_st = self.hear(y_cat)

        return y_st, y_hat_st  # both (B, 3, 256, 256)


if __name__ == '__main__':

    source = torch.randn(8, 3, 512, 512)
    target = torch.randn(8, 3, 512, 512)
    net = FSGenerator(
        id_ckpt="/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/checkpoints/"
                "face_id/ms1mv3_arcface_r100_fp16_backbone.pth",
        mouth_net_param={
            'use': False
        }
    )
    result, _, _ = net(source, target)
    print('result:', result.shape)

    # stage2 = FSHearNet(
    #     aei_path=make_abs_path("../../trainer/faceshifter/out/faceshifter_vanilla/epoch=32-step=509999.ckpt")
    # )
    # final_out, _ = stage2(source, target)
    # print('final out:', final_out.shape)
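
Usage note: with infer=True the source is assumed to be pre-aligned, so the registered trans_matrix warp is skipped, whereas in training mode the warp_affine call above re-aligns the crop before ArcFace sees it. A minimal inference sketch (the checkpoint path is a placeholder):

net = FSGenerator(id_ckpt='/path/to/arcface_backbone.pth', mouth_net_param={'use': False}).eval()
with torch.no_grad():
    swapped, id_vec, att = net(source, target, infer=True)  # source/target: (B, 3, 256, 256)
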
modules/networks/simswap.py
ADDED
@@ -0,0 +1,230 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################


import torch
import torch.nn as nn

from modules.layers.simswap.base_model import BaseModel
from modules.layers.simswap.fs_networks_fix import Generator_Adain_Upsample

from modules.layers.simswap.pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network
        from third_party.arcface import iresnet100
        netArc_pth = "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/" \
                     "checkpoints/face_id/ms1mv3_arcface_r100_fp16_backbone.pth"  # opt.Arc_path
        self.netArc = iresnet100(pretrained=False, fp16=False)
        self.netArc.load_state_dict(torch.load(netArc_pth, map_location="cpu"))
        # netArc_checkpoint = opt.Arc_path
        # netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        # self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)
        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        if self.isTrain:
            # define loss functions
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # load networks
            if opt.continue_train:
                pretrained_path = '' if not self.isTrain else opt.load_pretrain
                # print(pretrained_path)
                self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
            torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        raise ValueError('Not used')
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        raise ValueError('Not used')
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


if __name__ == "__main__":
    import os
    import argparse

    def str2bool(v):
        # note the tuple: `in ('true')` would be a substring test on the string 'true'
        return v.lower() in ('true',)


    class TrainOptions:
        def __init__(self):
            self.parser = argparse.ArgumentParser()
            self.initialized = False

        def initialize(self):
            self.parser.add_argument('--name', type=str, default='simswap',
                                     help='name of the experiment. It decides where to store samples and models')
            self.parser.add_argument('--gpu_ids', default='0')
            self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
                                     help='models are saved here')
            self.parser.add_argument('--isTrain', type=str2bool, default='True')

            # input/output sizes
            self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')

            # for displays
            self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')

            # for training
            self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2",
                                     help='path to the face swapping dataset')
            self.parser.add_argument('--continue_train', type=str2bool, default='False',
                                     help='continue training: load the latest model')
            self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test',
                                     help='load the pretrained model from the specified location')
            self.parser.add_argument('--which_epoch', type=str, default='10000',
                                     help='which epoch to load? set to latest to use the latest cached model')
            self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
            self.parser.add_argument('--niter', type=int, default=10000, help='# of iters at the starting learning rate')
            self.parser.add_argument('--niter_decay', type=int, default=10000,
                                     help='# of iters to linearly decay the learning rate to zero')
            self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
            self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
            self.parser.add_argument('--Gdeep', type=str2bool, default='False')

            # for discriminators
            self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
            self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
            self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')

            self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar',
                                     help="path to the ArcFace identity checkpoint")
            self.parser.add_argument("--total_step", type=int, default=1000000, help='total training steps')
            self.parser.add_argument("--log_frep", type=int, default=200, help='frequency for printing log information')
            self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequency for sampling')
            self.parser.add_argument("--model_freq", type=int, default=10000, help='frequency for saving the model')

            self.isTrain = True
            self.initialized = True  # guard against registering the arguments twice

        def parse(self, save=True):
            if not self.initialized:
                self.initialize()
            self.opt = self.parser.parse_args()
            self.opt.isTrain = self.isTrain  # train or test

            args = vars(self.opt)

            print('------------ Options -------------')
            for k, v in sorted(args.items()):
                print('%s: %s' % (str(k), str(v)))
            print('-------------- End ----------------')

            # save to the disk
            # if self.opt.isTrain:
            #     expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
            #     util.mkdirs(expr_dir)
            #     if save and not self.opt.continue_train:
            #         file_name = os.path.join(expr_dir, 'opt.txt')
            #         with open(file_name, 'wt') as opt_file:
            #             opt_file.write('------------ Options -------------\n')
            #             for k, v in sorted(args.items()):
            #                 opt_file.write('%s: %s\n' % (str(k), str(v)))
            #             opt_file.write('-------------- End ----------------\n')
            return self.opt


    source = torch.randn(8, 3, 256, 256).cuda()
    target = torch.randn(8, 3, 256, 256).cuda()

    opt = TrainOptions().parse()
    model = fsModel()
    model.initialize(opt)

    import torch.nn.functional as F
    img_id_112 = F.interpolate(source, size=(112, 112), mode='bicubic')
    latent_id = model.netArc(img_id_112)
    latent_id = F.normalize(latent_id, p=2, dim=1)

    img_fake = model.netG(target, latent_id)
    gen_logits, _ = model.netD(img_fake.detach(), None)
    loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

    real_logits, _ = model.netD(source, None)

    print('img_fake:', img_fake.shape, 'real_logits:', real_logits.shape)
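
The demo above computes only the fake half of the hinge loss; for completeness, a sketch of the matching real term and the generator's adversarial term (variable names are assumptions):

loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
loss_D = loss_Dgen + loss_Dreal  # full hinge discriminator loss

gen_logits_G, _ = model.netD(img_fake, None)  # no detach, so gradients reach netG
loss_G_adv = (-gen_logits_G).mean()
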
third_party/arcface/__init__.py
ADDED
@@ -0,0 +1,2 @@
from third_party.arcface.iresnet import iresnet18, iresnet34, iresnet50, iresnet100
from third_party.arcface.mouth_net import MouthNet
third_party/arcface/dataloaderx.py
ADDED
@@ -0,0 +1,67 @@
"""
A copy from https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/dataset.py
"""

import queue as Queue
import threading

import torch
from torch.utils.data import DataLoader


class BackgroundGenerator(threading.Thread):
    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank,
                                                 non_blocking=True)

    def __next__(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch
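
A minimal usage sketch (the dataset is a stand-in; DataLoaderX prefetches each batch onto the GPU given by local_rank via a side CUDA stream, so it assumes CUDA is available):

from torch.utils.data import TensorDataset

ds = TensorDataset(torch.randn(32, 3, 112, 112), torch.zeros(32, dtype=torch.long))
loader = DataLoaderX(local_rank=0, dataset=ds, batch_size=8, num_workers=2)
for imgs, labels in loader:
    pass  # tensors already reside on cuda:0 when the loop body runs
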
third_party/arcface/iresnet.py
ADDED
@@ -0,0 +1,311 @@
import torch
from torch import nn

__all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
    ):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        dropout=0,
        num_features=512,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        fp16=False,
        fc_scale=7 * 7,
    ):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        self.fc_scale = fc_scale
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
downsample,
|
147 |
+
self.groups,
|
148 |
+
self.base_width,
|
149 |
+
previous_dilation,
|
150 |
+
)
|
151 |
+
)
|
152 |
+
self.inplanes = planes * block.expansion
|
153 |
+
for _ in range(1, blocks):
|
154 |
+
layers.append(
|
155 |
+
block(
|
156 |
+
self.inplanes,
|
157 |
+
planes,
|
158 |
+
groups=self.groups,
|
159 |
+
base_width=self.base_width,
|
160 |
+
dilation=self.dilation,
|
161 |
+
)
|
162 |
+
)
|
163 |
+
|
164 |
+
return nn.Sequential(*layers)
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
with torch.cuda.amp.autocast(self.fp16):
|
168 |
+
x = self.conv1(x)
|
169 |
+
x = self.bn1(x)
|
170 |
+
x = self.prelu(x)
|
171 |
+
x = self.layer1(x)
|
172 |
+
x = self.layer2(x)
|
173 |
+
x = self.layer3(x)
|
174 |
+
x = self.layer4(x)
|
175 |
+
x = self.bn2(x)
|
176 |
+
# print(x.shape)
|
177 |
+
x = torch.flatten(x, 1)
|
178 |
+
x = self.dropout(x)
|
179 |
+
x = self.fc(x.float() if self.fp16 else x)
|
180 |
+
x = self.features(x)
|
181 |
+
return x
|
182 |
+
|
183 |
+
|
184 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
185 |
+
model = IResNet(block, layers, **kwargs)
|
186 |
+
if pretrained:
|
187 |
+
model_dir = {
|
188 |
+
'iresnet18': './weights/r18-backbone.pth',
|
189 |
+
'iresnet34': './weights/r34-backbone.pth',
|
190 |
+
'iresnet50': './weights/r50-backbone.pth',
|
191 |
+
'iresnet100': './weights/r100-backbone.pth',
|
192 |
+
}
|
193 |
+
pre_trained_weights = torch.load(model_dir[arch], map_location=torch.device('cpu'))
|
194 |
+
|
195 |
+
tmp_dict = {}
|
196 |
+
for key in pre_trained_weights:
|
197 |
+
# if 'features' in key or 'fc' in key:
|
198 |
+
# print('skip %s' % key)
|
199 |
+
# continue
|
200 |
+
tmp_dict[key] = pre_trained_weights[key]
|
201 |
+
|
202 |
+
# get 'iresnet' model layers which don't exist in 'arcxx' and insert to tmp
|
203 |
+
model_dict = model.state_dict()
|
204 |
+
for key in model_dict:
|
205 |
+
if key not in tmp_dict:
|
206 |
+
tmp_dict[key] = model_dict[key]
|
207 |
+
|
208 |
+
model.load_state_dict(tmp_dict, strict=False)
|
209 |
+
print("load pre-trained iresnet from %s" % model_dir[arch])
|
210 |
+
|
211 |
+
return model
|
212 |
+
|
213 |
+
|
214 |
+
def iresnet18(pretrained=False, progress=True, **kwargs):
|
215 |
+
return _iresnet(
|
216 |
+
"iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
def iresnet34(pretrained=False, progress=True, **kwargs):
|
221 |
+
return _iresnet(
|
222 |
+
"iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
def iresnet50(pretrained=False, progress=True, **kwargs):
|
227 |
+
return _iresnet(
|
228 |
+
"iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
def iresnet100(pretrained=False, progress=True, **kwargs):
|
233 |
+
return _iresnet(
|
234 |
+
"iresnet100", IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs
|
235 |
+
)
|
236 |
+
|
237 |
+
|
238 |
+
def iresnet200(pretrained=False, progress=True, **kwargs):
|
239 |
+
return _iresnet(
|
240 |
+
"iresnet200", IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs
|
241 |
+
)
|
242 |
+
|
243 |
+
|
244 |
+
@torch.no_grad()
|
245 |
+
def identification(folder: str = './images', target_idx: int = 0):
|
246 |
+
import os
|
247 |
+
from PIL import Image
|
248 |
+
import torch
|
249 |
+
import torchvision.transforms as transforms
|
250 |
+
import torch.nn.functional as F
|
251 |
+
import kornia
|
252 |
+
import numpy as np
|
253 |
+
|
254 |
+
os.makedirs('crop', exist_ok=True)
|
255 |
+
img_list = os.listdir(folder)
|
256 |
+
img_list.sort()
|
257 |
+
n = len(img_list)
|
258 |
+
trans = transforms.Compose([
|
259 |
+
transforms.Resize(256),
|
260 |
+
transforms.CenterCrop(224),
|
261 |
+
transforms.ToTensor(),
|
262 |
+
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
263 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
264 |
+
])
|
265 |
+
trans_matrix = torch.tensor(
|
266 |
+
[[[1.07695457, -0.03625215, -1.56352194],
|
267 |
+
[0.03625215, 1.07695457, -5.32134629]]],
|
268 |
+
requires_grad=False).float().cuda()
|
269 |
+
|
270 |
+
fid_model = iresnet50(pretrained=True).cuda().eval()
|
271 |
+
|
272 |
+
def save_tensor_to_img(tensor: torch.Tensor, path: str, scale=255):
|
273 |
+
tensor = tensor.permute(0, 2, 3, 1)[0] # in [0,1]
|
274 |
+
tensor = tensor.clamp(0, 1)
|
275 |
+
tensor = tensor * scale
|
276 |
+
tensor_np = tensor.cpu().numpy().astype(np.uint8)
|
277 |
+
if tensor_np.shape[-1] == 1: # channel dim
|
278 |
+
tensor_np = tensor_np.repeat(3, axis=-1)
|
279 |
+
tensor_img = Image.fromarray(tensor_np)
|
280 |
+
tensor_img.save(path)
|
281 |
+
|
282 |
+
feats = torch.zeros((n, 512), dtype=torch.float32).cuda()
|
283 |
+
for idx, img_path in enumerate(img_list):
|
284 |
+
img_pil = Image.open(os.path.join(folder, img_path)).convert('RGB')
|
285 |
+
img_tensor = trans(img_pil).unsqueeze(0).cuda()
|
286 |
+
|
287 |
+
# img_tensor = kornia.geometry.transform.warp_affine(img_tensor, trans_matrix, (256, 256))
|
288 |
+
save_tensor_to_img(img_tensor / 2 + 0.5, path=os.path.join('./crop', img_path))
|
289 |
+
img_tensor = F.interpolate(img_tensor, size=112, mode="bilinear", align_corners=True) # to 112
|
290 |
+
|
291 |
+
feat = fid_model(img_tensor)
|
292 |
+
feats[idx] = feat
|
293 |
+
|
294 |
+
target_feat = feats[target_idx].unsqueeze(0)
|
295 |
+
cosine_sim = F.cosine_similarity(target_feat, feats, 1)
|
296 |
+
print(cosine_sim.shape)
|
297 |
+
|
298 |
+
print('====== similarity with %s ======' % img_list[target_idx])
|
299 |
+
for idx in range(n):
|
300 |
+
print('[%d] %s = %.2f' % (idx, img_list[idx], float(cosine_sim[idx].cpu())))
|
301 |
+
|
302 |
+
|
303 |
+
if __name__ == '__main__':
|
304 |
+
import argparse
|
305 |
+
|
306 |
+
parser = argparse.ArgumentParser(description="arcface")
|
307 |
+
parser.add_argument("-i", "--target_idx", type=int, default=0)
|
308 |
+
args = parser.parse_args()
|
309 |
+
|
310 |
+
identification(target_idx=args.target_idx)
|
311 |
+
|
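One detail worth noting in IResNet: fc_scale is the spatial size of the last feature map that feeds the fully-connected layer. The stem is stride 1 and each of the four stages is stride 2, so a 112x112 input is downsampled by 16 to a 7x7 map, hence the default fc_scale = 7 * 7. A short sketch (not part of the commit) of how a smaller crop changes it, using the (28, 56, 84, 112) mouth crop that appears later in this commit:

import math

crop_param = (28, 56, 84, 112)            # (w1, h1, w2, h2) on the 112x112 aligned face
crop_h = crop_param[3] - crop_param[1]    # 56
crop_w = crop_param[2] - crop_param[0]    # 56
# Each input dimension maps to ceil(dim / 112 * 7) cells in the final feature map.
fc_scale = math.ceil(crop_h / 112 * 7) * math.ceil(crop_w / 112 * 7)  # 4 * 4 = 16
model = iresnet50(num_features=128, fc_scale=fc_scale)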
third_party/arcface/load_dataset.py
ADDED
@@ -0,0 +1,202 @@
import os
import numbers

import torch
import mxnet as mx
from PIL import Image
from torch.utils import data
from torchvision import transforms

import numpy as np
import PIL.Image as Image


""" Original mxnet dataset
"""
class MXFaceDataset(data.Dataset):
    def __init__(self, root_dir, crop_param=(0, 0, 112, 112)):
        super(MXFaceDataset, self).__init__()
        self.transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.root_dir = root_dir
        self.crop_param = crop_param
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample: Image = transforms.ToPILImage()(sample)
            sample = sample.crop(self.crop_param)
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)


""" MXNet binary dataset reader.
    Refer to https://github.com/deepinsight/insightface.
"""
import pickle
from typing import List
from mxnet import ndarray as nd
class ReadMXNet(object):
    def __init__(self, val_targets, rec_prefix, image_size=(112, 112)):
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        self.rec_prefix = rec_prefix
        self.val_targets = val_targets

    def init_dataset(self, val_targets, data_dir, image_size):
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = self.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)

    def load_bin(self, path, image_size):
        try:
            with open(path, 'rb') as f:
                bins, issame_list = pickle.load(f)  # py2
        except UnicodeDecodeError as e:
            with open(path, 'rb') as f:
                bins, issame_list = pickle.load(f, encoding='bytes')  # py3
        data_list = []
        # for flip in [0, 1]:
        #     data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
        #     data_list.append(data)
        for idx in range(len(issame_list) * 2):
            _bin = bins[idx]
            img = mx.image.imdecode(_bin)
            if img.shape[1] != image_size[0]:
                img = mx.image.resize_short(img, image_size[0])
            img = nd.transpose(img, axes=(2, 0, 1))  # (C, H, W)

            img = nd.transpose(img, axes=(1, 2, 0))  # (H, W, C)
            fig = Image.fromarray(img.asnumpy(), mode='RGB')
            data_list.append(fig)
            # data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
            if idx % 1000 == 0:
                print('loading bin', idx)

            # # save img to '/home/yuange/dataset/LFW/rgb-arcface'
            # img = nd.transpose(img, axes=(1, 2, 0))  # (H, W, C)
            # # save_name = 'ind_' + str(idx) + '.bmp'
            # # import os
            # # save_name = os.path.join('/home/yuange/dataset/LFW/rgb-arcface', save_name)
            # import PIL.Image as Image
            # fig = Image.fromarray(img.asnumpy(), mode='RGB')
            # # fig.save(save_name)

        print('load finished', len(data_list))
        return data_list, issame_list


"""
Evaluation Benchmark
"""
class EvalDataset(data.Dataset):
    def __init__(self,
                 target: str = 'lfw',
                 rec_folder: str = '',
                 transform=None,
                 crop_param=(0, 0, 112, 112)
                 ):
        print("=> Pre-loading images ...")
        self.target = target
        self.rec_folder = rec_folder
        mx_reader = ReadMXNet(target, rec_folder)
        path = os.path.join(rec_folder, target + ".bin")
        all_img, issame_list = mx_reader.load_bin(path, (112, 112))
        self.all_img = all_img
        self.issame_list = []
        for i in range(len(issame_list)):
            flag = 0 if issame_list[i] else 1  # 0: same identity
            self.issame_list.append(flag)

        self.transform = transform
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        self.crop_param = crop_param

    def __getitem__(self, index):
        img1 = self.all_img[index * 2]
        img2 = self.all_img[index * 2 + 1]
        same = self.issame_list[index]

        save_index = 11
        if index == save_index:
            img1.save('img1_ori.jpg')
            img2.save('img2_ori.jpg')

        img1 = img1.crop(self.crop_param)
        img2 = img2.crop(self.crop_param)
        if index == save_index:
            img1.save('img1_crop.jpg')
            img2.save('img2_crop.jpg')

        img1 = self.transform(img1)
        img2 = self.transform(img2)

        return img1, img2, same

    def __len__(self):
        return len(self.issame_list)


if __name__ == '__main__':

    import PIL.Image as Image
    import time

    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    mx.random.seed(1)

    is_gray = False

    # NOTE: FaceByRandOccMask is not defined or imported in this file; it comes
    # from the occlusion-augmentation training code and must be importable for
    # this self-test to run.
    train_set = FaceByRandOccMask(
        root_dir='/tmp/train_tmp/casia',
        local_rank=0,
        use_norm=True,
        is_gray=is_gray,
    )
    start = time.time()
    for idx in range(100):
        face, mask, label = train_set.__getitem__(idx)
        if idx < 15:
            face = ((face + 1) * 128).numpy().astype(np.uint8)
            face = np.transpose(face, (1, 2, 0))
            if is_gray:
                face = Image.fromarray(face[:, :, 0], mode='L')
            else:
                face = Image.fromarray(face, mode='RGB')
            face.save('face_{}.jpg'.format(idx))
    print('time cost: %d ms' % (int((time.time() - start) * 1000)))
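Note that the __main__ self-test above exercises FaceByRandOccMask rather than the classes defined in this file. A minimal sketch (not part of the commit) of the two classes that are defined here; the dataset path is a placeholder for a folder holding train.rec / train.idx and lfw.bin in insightface's record format:

from torch.utils.data import DataLoader

train_set = MXFaceDataset(root_dir='/data/ms1m-retinaface', crop_param=(0, 0, 112, 112))
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=8)
imgs, labels = next(iter(train_loader))   # imgs: (128, 3, 112, 112), normalized to [-1, 1]

eval_set = EvalDataset(target='lfw', rec_folder='/data/ms1m-retinaface')
img1, img2, same = eval_set[0]            # one verification pair; same == 0 means same identity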
third_party/arcface/margin_loss.py
ADDED
@@ -0,0 +1,463 @@
import torch
import torch.nn.functional as F
from torch import nn

from torch.nn import Parameter

import numpy as np

__all__ = ['Softmax', 'AMCosFace', 'AMArcFace', ]


MIN_NUM_PATCHES = 16


""" All losses can run in 'torch.distributed.DistributedDataParallel'.
"""

class Softmax(nn.Module):
    r"""Implementation of Softmax (normal classification head):
    Args:
        in_features: dimension (d_in) of input feature (B, d_in)
        out_features: dimension (d_out) of output feature (B, d_out)
        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
            if device_id=None, it will be trained on model parallel (or DDP). (recommend!)
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 device_id,
                 ):
        super(Softmax, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, embedding, label):
        """
        :param embedding: learned face representation
        :param label:
            - label >= 0: ground truth identity
            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
        :return:
        """
        if self.device_id is None:
            """ Regular linear layer.
            """
            out = F.linear(embedding, self.weight, self.bias)
        else:
            raise ValueError('DataParallel is not implemented yet.')
            # unreachable DP code path, kept for reference:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            bias = sub_biases[0].cuda(self.device_id[0])
            out = F.linear(temp_x, weight, bias)
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                bias = sub_biases[i].cuda(self.device_id[i])
                out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
        return out


""" Not Used """
class ArcFace(nn.Module):
    r"""Implementation of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
            if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta+m)
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.50, easy_margin=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m
        print('ArcFace, s=%.1f, m=%.2f' % (s, m))

        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = np.cos(m)
        self.sin_m = np.sin(m)
        self.th = np.cos(np.pi - m)
        self.mm = np.sin(np.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        if self.device_id is not None:
            one_hot = one_hot.cuda(self.device_id[0])
        else:
            one_hot = one_hot.cuda()
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # ------------- torch.where(out_i = {x_i if condition_i else y_i}) -------------
        output = (one_hot * phi) + (
            (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output


""" Not Used """
class CosFace(nn.Module):
    r"""Implementation of CosFace (https://arxiv.org/pdf/1801.09414.pdf):
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        device_id: the ID of GPU where the model will be trained by model parallel.
            if device_id=None, it will be trained on CPU without model parallel.
        s: norm of input feature
        m: margin
        cos(theta)-m
    """

    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.4):
        super(CosFace, self).__init__()
        print('CosFace, s=%.1f, m=%.2f' % (s, m))
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------

        if self.device_id is None:
            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        else:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size()).cuda()
        if self.device_id is not None:
            one_hot = one_hot.cuda(self.device_id[0])
            # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
            one_hot.scatter_(1, label.cuda(self.device_id[0]).view(-1, 1).long(), 1)
        else:
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # ------------- torch.where(out_i = {x_i if condition_i else y_i}) -------------
        output = (one_hot * phi) + (
            (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) + ')'


class AMCosFace(nn.Module):
    r"""Implementation of Adaptive Margin CosFace:
        cos(theta)-m+k(theta-a)
        When k is 0, AMCosFace degenerates into CosFace.
    Args:
        in_features: dimension (d_in) of input feature (B, d_in)
        out_features: dimension (d_out) of output feature (B, d_out)
        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
            if device_id=None, it will be trained on model parallel (or DDP). (recommend!)
        s: norm of input feature
        m: margin
        a: AM Loss
        k: AM Loss
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 device_id,
                 s: float = 64.0,
                 m: float = 0.4,
                 a: float = 1.2,
                 k: float = 0.1,
                 ):
        super(AMCosFace, self).__init__()
        print('AMCosFace, s=%.1f, m=%.2f, a=%.2f, k=%.2f' % (s, m, a, k))
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m
        self.a = a
        self.k = k

        """ Weight Matrix W (d_out, d_in) """
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embedding, label):
        """
        :param embedding: learned face representation
        :param label:
            - label >= 0: ground truth identity
            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
        :return:
        """
        if self.device_id is None:
            """ - embedding: shape is (B, d_in)
                - weight: shape is (d_out, d_in)
                - cosine: shape is (B, d_out)
                + F.normalize is very important here.
            """
            cosine = F.linear(F.normalize(embedding), F.normalize(self.weight))  # y = xA^T + b
        else:
            raise ValueError('DataParallel is not implemented yet.')
            # unreachable DP code path, kept for reference:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x),
                                                     F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)

        """ - index: the index of valid identity in label, shape is (d_valid, )
            + torch.where() returns a tuple indicating the index of each dimension
            + Example: index = torch.tensor([1, 3, 4])
        """
        index = torch.where(label != -1)[0]

        """ - m_hot: one-hot tensor of margin m_2, shape is (d_valid, d_out)
            + torch.tensor.scatter_(dim, index, source) is usually used to generate one-hot tensor
            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
                       index = torch.tensor([1, 3, 4])  # d_valid = index.shape[0] = 3
                       m_hot = torch.tensor([[0, 0, 0, 0, m, 0],
                                             [0, 0, 0, 0, 0, m],
                                             [0, 0, 0, m, 0, 0],
                                             ])
        """
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)

        """ logit(theta) = cos(theta) - m_2 + k * (theta - a)
            - theta = cosine.acos_()
            + Example: m_hot = torch.tensor([[0, 0, 0, 0, m-k(theta[0,4]-a), 0],
                                             [0, 0, 0, 0, 0, m-k(theta[1,5]-a)],
                                             [0, 0, 0, m-k(theta[2,3]-a), 0, 0],
                                             ])
        """
        a = self.a
        k = self.k
        m_hot[range(0, index.size()[0]), label[index]] -= k * (cosine[index, label[index]].acos_() - a)
        cosine[index] -= m_hot

        """ Because we have used F.normalize, we should rescale the logit term by s.
        """
        output = cosine * self.s

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) \
               + ', a = ' + str(self.a) \
               + ', k = ' + str(self.k) \
               + ')'


class AMArcFace(nn.Module):
    r"""Implementation of Adaptive Margin ArcFace:
        cos(theta+m-k(theta-a))
        When k is 0, AMArcFace degenerates into ArcFace.
    Args:
        in_features: dimension (d_in) of input feature (B, d_in)
        out_features: dimension (d_out) of output feature (B, d_out)
        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
            if device_id=None, it will be trained on model parallel (or DDP). (recommend!)
        s: norm of input feature
        m: margin
        a: AM Loss
        k: AM Loss
    """
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 device_id,
                 s: float = 64.0,
                 m: float = 0.5,
                 a: float = 1.2,
                 k: float = 0.1,
                 ):
        super(AMArcFace, self).__init__()
        print('AMArcFace, s=%.1f, m=%.2f, a=%.2f, k=%.2f' % (s, m, a, k))
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id

        self.s = s
        self.m = m
        self.a = a
        self.k = k

        """ Weight Matrix W (d_out, d_in) """
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embedding, label):
        """
        :param embedding: learned face representation
        :param label:
            - label >= 0: ground truth identity
            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
        :return:
        """
        if self.device_id is None:
            """ - embedding: shape is (B, d_in)
                - weight: shape is (d_out, d_in)
                - cosine: shape is (B, d_out)
                + F.normalize is very important here.
            """
            cosine = F.linear(F.normalize(embedding), F.normalize(self.weight))  # y = xA^T + b
        else:
            raise ValueError('DataParallel is not implemented yet.')
            # unreachable DP code path, kept for reference:
            x = input
            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
            temp_x = x.cuda(self.device_id[0])
            weight = sub_weights[0].cuda(self.device_id[0])
            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
            for i in range(1, len(self.device_id)):
                temp_x = x.cuda(self.device_id[i])
                weight = sub_weights[i].cuda(self.device_id[i])
                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x),
                                                     F.normalize(weight)).cuda(self.device_id[0])),
                                   dim=1)

        """ - index: the index of valid identity in label, shape is (d_valid, )
            + torch.where() returns a tuple indicating the index of each dimension
            + Example: index = torch.tensor([1, 3, 4])
        """
        index = torch.where(label != -1)[0]

        """ - m_hot: one-hot tensor of margin m_2, shape is (d_valid, d_out)
            + torch.tensor.scatter_(dim, index, source) is usually used to generate one-hot tensor
            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
                       index = torch.tensor([1, 3, 4])  # d_valid = index.shape[0] = 3
                       m_hot = torch.tensor([[0, 0, 0, 0, m, 0],
                                             [0, 0, 0, 0, 0, m],
                                             [0, 0, 0, m, 0, 0],
                                             ])
        """
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)

        """ logit(theta) = cos(theta + m_2 - k * (theta - a))
            - theta = cosine.acos_()
            + Example: m_hot = torch.tensor([[0, 0, 0, 0, m-k(theta[0,4]-a), 0],
                                             [0, 0, 0, 0, 0, m-k(theta[1,5]-a)],
                                             [0, 0, 0, m-k(theta[2,3]-a), 0, 0],
                                             ])
        """
        a = self.a
        k = self.k
        m_hot[range(0, index.size()[0]), label[index]] -= k * (cosine[index, label[index]].acos_() - a)

        cosine.acos_()
        cosine[index] += m_hot
        cosine.cos_().mul_(self.s)
        return cosine

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features = ' + str(self.in_features) \
               + ', out_features = ' + str(self.out_features) \
               + ', s = ' + str(self.s) \
               + ', m = ' + str(self.m) \
               + ', a = ' + str(self.a) \
               + ', k = ' + str(self.k) \
               + ')'


if __name__ == '__main__':
    cosine = torch.randn(6, 8) / 100
    cosine[0][2] = 0.3
    cosine[1][4] = 0.4
    cosine[2][6] = 0.5
    cosine[3][5] = 0.6
    cosine[4][3] = 0.7
    cosine[5][0] = 0.8
    label = torch.tensor([-1, 4, -1, 5, 3, -1])

    # layer = AMCosFace(in_features=8,
    #                   out_features=8,
    #                   device_id=None,
    #                   m=0.35, s=1.0,
    #                   a=1.2, k=0.1)

    # layer = Softmax(in_features=8,
    #                 out_features=8,
    #                 device_id=None)

    layer = AMArcFace(in_features=8,
                      out_features=8,
                      device_id=None,
                      m=0.5, s=1.0,
                      a=1.2, k=0.1)

    logit = layer(cosine, label)
    logit = F.softmax(logit, dim=-1)

    from utils.vis_tensor import plot_tensor
    plot_tensor((cosine, logit),
                ('embedding', 'logit'),
                'AMArc.jpg')
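For the ground-truth class, AMArcFace replaces ArcFace's fixed angular margin m with the adaptive margin m - k * (theta - a): easy samples (theta < a) receive a larger effective margin, hard samples a smaller one, and k = 0 recovers plain ArcFace. A tiny numeric sketch of that schedule using the defaults above (the sample angles are illustrative):

import math

def effective_margin(theta, m=0.5, a=1.2, k=0.1):
    # Additive angle applied to the ground-truth logit: cos(theta + m - k * (theta - a))
    return m - k * (theta - a)

for theta in (0.6, 1.2, 1.8):  # angle to the class center, in radians
    eff_m = effective_margin(theta)
    print('theta=%.1f  margin=%.3f  logit=%.3f' % (theta, eff_m, math.cos(theta + eff_m)))
# theta=0.6 -> margin 0.560 (easy sample, pushed harder)
# theta=1.8 -> margin 0.440 (hard sample, relaxed)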
third_party/arcface/mouth_net.py
ADDED
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from third_party.arcface.iresnet import iresnet50, iresnet100

class MouthNet(nn.Module):
    def __init__(self,
                 bisenet: nn.Module,
                 feature_dim: int = 64,
                 crop_param: tuple = (0, 0, 112, 112),
                 iresnet_pretrained: bool = False,
                 ):
        super(MouthNet, self).__init__()

        crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H, W)
        fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))

        self.bisenet = bisenet
        self.backbone = iresnet50(
            pretrained=iresnet_pretrained,
            num_features=feature_dim,
            fp16=False,
            fc_scale=fc_scale,
        )

        self.register_buffer(
            name="vgg_mean",
            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], requires_grad=False),
        )
        self.register_buffer(
            name="vgg_std",
            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], requires_grad=False),
        )

    def forward(self, x):
        # with torch.no_grad():
        #     x_mouth_mask = self.get_any_mask(x, par=[11, 12, 13], normalized=True)  # (B,1,H,W), in [0,1], 1: chosen
        x_mouth_mask = 1
        x_mouth = x * x_mouth_mask  # (B,3,112,112)
        mouth_feature = self.backbone(x_mouth)
        return mouth_feature

    def get_any_mask(self, img, par, normalized=False):
        # [0 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
        #  6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip',
        #  13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
        ori_size = img.size()[-1]
        with torch.no_grad():
            img = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                img = img * 0.5 + 0.5
                img = img.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            out = self.bisenet(img)[0]
            parsing = out.softmax(1).argmax(1)
        mask = torch.zeros_like(parsing)
        for p in par:
            mask = mask + ((parsing == p).float())
        mask = mask.unsqueeze(1)
        mask = F.interpolate(mask, size=ori_size, mode="bilinear", align_corners=True)
        return mask

    def save_backbone(self, path: str):
        torch.save(self.backbone.state_dict(), path)

    def load_backbone(self, path: str):
        self.backbone.load_state_dict(torch.load(path))


if __name__ == "__main__":
    from third_party.bisenet.bisenet import BiSeNet

    bisenet = BiSeNet(19)
    bisenet.load_state_dict(
        torch.load(
            "/gavin/datasets/hanbang/79999_iter.pth",
            map_location="cpu",
        )
    )
    bisenet.eval()
    bisenet.requires_grad_(False)

    crop_param = (28, 56, 84, 112)

    import numpy as np
    img = np.random.randn(112, 112, 3) * 225
    from PIL import Image
    img = Image.fromarray(img.astype(np.uint8))
    img = img.crop(crop_param)

    from torchvision import transforms
    trans = transforms.ToTensor()
    img = trans(img).unsqueeze(0)
    img = img.repeat(3, 1, 1, 1)
    print(img.shape)

    net = MouthNet(
        bisenet=bisenet,
        feature_dim=64,
        crop_param=crop_param
    )
    mouth_feat = net(img)
    print(mouth_feat.shape)

    import thop

    crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H, W)
    fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))
    backbone = iresnet100(
        pretrained=False,
        num_features=64,
        fp16=False,
        # fc_scale=fc_scale,
    )
    flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False)
    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
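Although forward() currently bypasses the parsing mask (x_mouth_mask = 1), get_any_mask is usable on its own to pull a mouth-and-lips region from the BiSeNet face parser (labels 11-13 in the table above). A sketch (not part of the commit), reusing the parser checkpoint path that appears elsewhere in this commit:

import torch
from third_party.bisenet.bisenet import BiSeNet
from third_party.arcface.mouth_net import MouthNet

bisenet = BiSeNet(19)
bisenet.load_state_dict(torch.load("/gavin/datasets/hanbang/79999_iter.pth", map_location="cpu"))
bisenet.eval()

net = MouthNet(bisenet=bisenet, feature_dim=64)
faces = torch.rand(2, 3, 112, 112)  # already in [0, 1], hence normalized=True
mask = net.get_any_mask(faces, par=[11, 12, 13], normalized=True)
print(mask.shape)  # (2, 1, 112, 112), soft mask with values in [0, 1]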
third_party/arcface/mouth_net_eval.py
ADDED
@@ -0,0 +1,69 @@
import argparse
import pytorch_lightning as pl
import numpy as np
import torch

from third_party.arcface.mouth_net_pl import MouthNetPL
from third_party.arcface.mouth_net import MouthNet


class MouthTest(object):
    def __init__(self):
        self.dataset_len = 400

        self.fixer_crop_param = (28, 56, 84, 112)
        self.fixer_casia_model = MouthNet(
            bisenet=None,
            feature_dim=128,
            crop_param=self.fixer_crop_param
        ).cuda()
        fixer_path = "/gavin/code/FaceSwapping/modules/third_party/arcface/weights/fixer_net_casia_28_56_84_112.pth"
        self.fixer_casia_model.load_backbone(fixer_path)
        self.fixer_casia_model.eval()
        self.fixer_t = np.zeros((self.dataset_len, 128), dtype=np.float32)
        self.fixer_s = np.zeros_like(self.fixer_t, dtype=np.float32)  # each embedding repeats 10 times in ffplus
        self.fixer_r = np.zeros_like(self.fixer_t, dtype=np.float32)
        print('Fixer model loaded.')


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.val_targets = []
    args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface"

    fixer_net = MouthNetPL.load_from_checkpoint(
        "/apdcephfs/share_1290939/gavinyuan/out/fixernet_casia/epoch=22-step=10999-v1.ckpt",
        map_location='cpu', strict=False,
        num_classes=10572,
        batch_size=128,
        dim_feature=128,
        rec_folder=args.rec_folder,
        header_type="AMCosFace",
        crop=(28, 56, 84, 112),
    )

    lower_net_1 = MouthNetPL.load_from_checkpoint(
        "/apdcephfs/share_1290939/gavinyuan/out/mouth_net_1/epoch=24-step=242999.ckpt",
        map_location='cpu', strict=False,
        num_classes=93431,
        batch_size=128,
        dim_feature=128,
        rec_folder=args.rec_folder,
        header_type="AMArcFace",
        crop=(28, 56, 84, 112),
    )

    # test_net = fixer_net
    test_net = lower_net_1
    trainer = pl.Trainer(
        logger=False,
        gpus=1,
        distributed_backend='dp',
        benchmark=True,
    )
    trainer.test(test_net)

    # print('Fixer model loading...')
    # m_test = MouthTest()
third_party/arcface/mouth_net_pl.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import sklearn
|
11 |
+
from sklearn.metrics import roc_curve, auc
|
12 |
+
from scipy.spatial.distance import cdist
|
13 |
+
|
14 |
+
from third_party.arcface.mouth_net import MouthNet
|
15 |
+
from third_party.arcface.margin_loss import Softmax, AMArcFace, AMCosFace
|
16 |
+
from third_party.arcface.load_dataset import MXFaceDataset, EvalDataset
|
17 |
+
from third_party.bisenet.bisenet import BiSeNet
|
18 |
+
|
19 |
+
|
20 |
+
class MouthNetPL(pl.LightningModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
num_classes: int,
|
24 |
+
batch_size: int = 256,
|
25 |
+
dim_feature: int = 128,
|
26 |
+
header_type: str = 'AMArcFace',
|
27 |
+
header_params: tuple = (64.0, 0.5, 0.0, 0.0), # (s, m, a, k)
|
28 |
+
rec_folder: str = "/gavin/datasets/msml/ms1m-retinaface",
|
29 |
+
learning_rate: int = 0.1,
|
30 |
+
crop: tuple = (0, 0, 112, 112), # (w1,h1,w2,h2)
|
31 |
+
):
|
32 |
+
super(MouthNetPL, self).__init__()
|
33 |
+
|
34 |
+
# self.img_size = (112, 112)
|
35 |
+
|
36 |
+
''' mouth feature extractor '''
|
37 |
+
bisenet = BiSeNet(19)
|
38 |
+
bisenet.load_state_dict(
|
39 |
+
torch.load(
|
40 |
+
"/gavin/datasets/hanbang/79999_iter.pth",
|
41 |
+
map_location="cpu",
|
42 |
+
)
|
43 |
+
)
|
44 |
+
bisenet.eval()
|
45 |
+
bisenet.requires_grad_(False)
|
46 |
+
self.mouth_net = MouthNet(
|
47 |
+
bisenet=None,
|
48 |
+
feature_dim=dim_feature,
|
49 |
+
crop_param=crop,
|
50 |
+
iresnet_pretrained=False,
|
51 |
+
)
|
52 |
+
|
53 |
+
''' head & loss '''
|
54 |
+
self.automatic_optimization = False
|
55 |
+
self.dim_feature = dim_feature
|
56 |
+
self.num_classes = num_classes
|
57 |
+
self._prepare_header(header_type, header_params)
|
58 |
+
self.cls_criterion = torch.nn.CrossEntropyLoss()
|
59 |
+
self.learning_rate = learning_rate
|
60 |
+
|
61 |
+
''' dataset '''
|
62 |
+
assert os.path.exists(rec_folder)
|
63 |
+
self.rec_folder = rec_folder
|
64 |
+
self.batch_size = batch_size
|
65 |
+
self.crop_param = crop
|
66 |
+
|
67 |
+
''' validation '''
|
68 |
+
|
69 |
+
def _prepare_header(self, head_type, header_params):
|
70 |
+
dim_in = self.dim_feature
|
71 |
+
dim_out = self.num_classes
|
72 |
+
|
73 |
+
""" Get hyper-params of header """
|
74 |
+
s, m, a, k = header_params
|
75 |
+
|
76 |
+
""" Choose the header """
|
77 |
+
if 'Softmax' in head_type:
|
78 |
+
self.classification = Softmax(dim_in, dim_out, device_id=None)
|
79 |
+
elif 'AMCosFace' in head_type:
|
80 |
+
self.classification = AMCosFace(dim_in, dim_out,
|
81 |
+
device_id=None,
|
82 |
+
s=s, m=m,
|
83 |
+
a=a, k=k,
|
84 |
+
)
|
85 |
+
elif 'AMArcFace' in head_type:
|
86 |
+
self.classification = AMArcFace(dim_in, dim_out,
|
87 |
+
device_id=None,
|
88 |
+
s=s, m=m,
|
89 |
+
a=a, k=k,
|
90 |
+
)
|
91 |
+
else:
|
92 |
+
raise ValueError('Header type error!')
|
93 |
+
|
94 |
+
def forward(self, x, label=None):
|
95 |
+
feat = self.mouth_net(x)
|
96 |
+
if self.training:
|
97 |
+
assert label is not None
|
98 |
+
cls = self.classification(feat, label)
|
99 |
+
return feat, cls
|
100 |
+
else:
|
101 |
+
return feat
|
102 |
+
|
103 |
+
def training_step(self, batch, batch_idx):
|
104 |
+
opt = self.optimizers(use_pl_optimizer=True)
|
105 |
+
img, label = batch
|
106 |
+
|
107 |
+
mouth_feat, final_cls = self(img, label)
|
108 |
+
|
109 |
+
cls_loss = self.cls_criterion(final_cls, label)
|
110 |
+
|
111 |
+
opt.zero_grad()
|
112 |
+
self.manual_backward(cls_loss)
|
113 |
+
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5, norm_type=2)
|
114 |
+
opt.step()
|
115 |
+
|
116 |
+
''' loss logging '''
|
117 |
+
self.logging_dict({"cls_loss": cls_loss}, prefix="train / ")
|
118 |
+
self.logging_lr()
|
119 |
+
if batch_idx % 50 == 0 and self.local_rank == 0:
|
120 |
+
print('loss=', cls_loss)
|
121 |
+
|
122 |
+
return cls_loss
|
123 |
+
|
124 |
+
def training_epoch_end(self, outputs):
|
125 |
+
sch = self.lr_schedulers()
|
126 |
+
sch.step()
|
127 |
+
|
128 |
+
lr = -1
|
129 |
+
opts = self.trainer.optimizers
|
130 |
+
for opt in opts:
|
131 |
+
for param_group in opt.param_groups:
|
132 |
+
lr = param_group["lr"]
|
133 |
+
break
|
134 |
+
print('learning rate changed to %.6f' % lr)
|
135 |
+
|
136 |
+
# def validation_step(self, batch, batch_idx):
|
137 |
+
# return self.test_step(batch, batch_idx)
|
138 |
+
#
|
139 |
+
# def validation_step_end(self, outputs):
|
140 |
+
# return self.test_step_end(outputs)
|
141 |
+
#
|
142 |
+
# def validation_epoch_end(self, outputs):
|
143 |
+
# return self.test_step_end(outputs)
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def save_tensor(tensor: torch.Tensor, path: str, b_idx: int = 0):
|
147 |
+
tensor = (tensor + 1.) * 127.5
|
148 |
+
img = tensor.permute(0, 2, 3, 1)[b_idx].cpu().numpy()
|
149 |
+
from PIL import Image
|
150 |
+
img_pil = Image.fromarray(img.astype(np.uint8))
|
151 |
+
img_pil.save(path)
|
152 |
+
|
153 |
+
def test_step(self, batch, batch_idx):
|
154 |
+
img1, img2, same = batch
|
155 |
+
feat1 = self.mouth_net(img1)
|
156 |
+
feat2 = self.mouth_net(img2)
|
157 |
+
return feat1, feat2, same
|
158 |
+
|
159 |
+
def test_step_end(self, outputs):
|
160 |
+
feat1, feat2, same = outputs
|
161 |
+
feat1 = feat1.cpu().numpy()
|
162 |
+
feat2 = feat2.cpu().numpy()
|
163 |
+
same = same.cpu().numpy()
|
164 |
+
|
165 |
+
feat1 = sklearn.preprocessing.normalize(feat1)
|
166 |
+
feat2 = sklearn.preprocessing.normalize(feat2)
|
167 |
+
|
168 |
+
predict_label = []
|
169 |
+
num = feat1.shape[0]
|
170 |
+
for i in range(num):
|
171 |
+
            dis_cos = cdist(feat1[i, None], feat2[i, None], metric='cosine')
            predict_label.append(dis_cos[0, 0])
        predict_label = np.array(predict_label)

        return {
            "pred": predict_label,
            "gt": same,
        }

    def test_epoch_end(self, outputs):
        print(outputs)
        pred, same = None, None
        for batch_output in outputs:
            if pred is None and same is None:
                pred = batch_output["pred"]
                same = batch_output["gt"]
            else:
                pred = np.concatenate([pred, batch_output["pred"]])
                same = np.concatenate([same, batch_output["gt"]])
        print(pred.shape, same.shape)

        fpr, tpr, threshold = roc_curve(same, pred)
        acc = tpr[np.argmin(np.abs(tpr - (1 - fpr)))]  # choose the threshold where TPR is closest to 1 - FPR
        print("=> verification finished, acc=%.4f" % (acc))

        ''' save pth '''
        pth_path = "./weights/fixer_net_casia_%s.pth" % ('_'.join(str(x) for x in self.crop_param))
        self.mouth_net.save_backbone(pth_path)
        print("=> model save to %s" % pth_path)
        mouth_net = MouthNet(
            bisenet=None,
            feature_dim=self.dim_feature,
            crop_param=self.crop_param,
        )
        mouth_net.load_backbone(pth_path)
        print("=> MouthNet pth checked")

        return acc

    def logging_dict(self, log_dict, prefix=None):
        for key, val in log_dict.items():
            if prefix is not None:
                key = prefix + key
            self.log(key, val)

    def logging_lr(self):
        opts = self.trainer.optimizers
        for idx, opt in enumerate(opts):
            lr = None
            for param_group in opt.param_groups:
                lr = param_group["lr"]
                break
            self.log(f"lr_{idx}", lr)

    def configure_optimizers(self):
        params = list(self.parameters())
        # linear lr scaling: the base lr is defined for a global batch size of 512
        learning_rate = self.learning_rate / 512 * self.batch_size * torch.cuda.device_count()
        optimizer = torch.optim.SGD(params, lr=learning_rate,
                                    momentum=0.9, weight_decay=5e-4)
        print('lr is set as %.5f due to the global batch_size %d' % (
            learning_rate, self.batch_size * torch.cuda.device_count()))

        def lr_step_func(epoch):
            # warmup is effectively disabled (`epoch < 0` is never true);
            # lr decays by 10x at epochs 11, 17, 22: 0.1, 0.01, 0.001, 0.0001
            return ((epoch + 1) / (4 + 1)) ** 2 if epoch < 0 else 0.1 ** len(
                [m for m in [11, 17, 22] if m - 1 <= epoch])
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer=optimizer, lr_lambda=lr_step_func)

        return [optimizer], [scheduler]

    def train_dataloader(self):
        dataset = MXFaceDataset(
            root_dir=self.rec_folder,
            crop_param=self.crop_param,
        )
        train_loader = DataLoader(
            dataset, self.batch_size, num_workers=24, shuffle=True, drop_last=True
        )
        return train_loader

    def val_dataloader(self):
        return self.test_dataloader()

    def test_dataloader(self):
        dataset = EvalDataset(
            rec_folder=self.rec_folder,
            target='lfw',
            crop_param=self.crop_param,
        )
        test_loader = DataLoader(
            dataset, 20, num_workers=12, shuffle=False, drop_last=False
        )
        return test_loader


def start_train():
    import os
    import argparse
    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    import wandb
    from pytorch_lightning.loggers import WandbLogger

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--gpus",
        type=str,
        default=None,
        help="Number of gpus to use (e.g. '0,1,2,3'). Will use all if not given.",
    )
    parser.add_argument("-n", "--name", type=str, required=True, help="Name of the run.")
    parser.add_argument("-pj", "--project", type=str, default="mouthnet", help="Name of the project.")

    parser.add_argument("-rp", "--resume_checkpoint_path",
                        type=str, default=None, help="path of checkpoint for resuming")
    parser.add_argument("-p", "--saving_folder",
                        type=str, default="/apdcephfs/share_1290939/gavinyuan/out", help="saving folder")
    parser.add_argument("--wandb_resume",
                        type=str, default=None, help="resume wandb logging from the input id")

    parser.add_argument("--header_type", type=str, default="AMArcFace", help="loss type.")

    parser.add_argument("-bs", "--batch_size", type=int, default=128, help="bs.")
    # note: argparse `type=bool` treats any non-empty string as True
    parser.add_argument("-fs", "--fast_dev_run", type=bool, default=False, help="pytorch.lightning fast_dev_run")
    args = parser.parse_args()
    args.val_targets = []
    # args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface"
    # num_classes = 93431
    args.rec_folder = "/gavin/datasets/msml/casia"
    num_classes = 10572

    save_path = os.path.join(args.saving_folder, args.name)
    os.makedirs(save_path, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_path,
        monitor="train / cls_loss",
        save_top_k=10,
        verbose=True,
        every_n_train_steps=200,
    )

    torch.cuda.empty_cache()
    mouth_net = MouthNetPL(
        num_classes=num_classes,
        batch_size=args.batch_size,
        dim_feature=128,
        rec_folder=args.rec_folder,
        header_type=args.header_type,
        crop=(28, 56, 84, 112),
    )

    if args.wandb_resume is None:
        resume = "allow"
        wandb_id = wandb.util.generate_id()
    else:
        resume = True
        wandb_id = args.wandb_resume
    logger = WandbLogger(
        project=args.project,
        entity="gavinyuan",
        name=args.name,
        resume=resume,
        id=wandb_id,
    )

    trainer = pl.Trainer(
        gpus=-1 if args.gpus is None else torch.cuda.device_count(),
        callbacks=[checkpoint_callback],
        logger=logger,
        weights_save_path=save_path,
        resume_from_checkpoint=args.resume_checkpoint_path,
        gradient_clip_val=0,
        max_epochs=25,
        num_sanity_val_steps=1,
        fast_dev_run=args.fast_dev_run,
        val_check_interval=50,
        progress_bar_refresh_rate=1,
        distributed_backend="ddp",
        benchmark=True,
    )
    trainer.fit(mouth_net)


if __name__ == "__main__":
    start_train()
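For reference, the accuracy printed by test_epoch_end above is read off the ROC curve at the operating point where TPR is closest to 1 - FPR, i.e. an equal-error-rate style estimate swept over all cosine-distance thresholds. A minimal, self-contained sketch of just that selection step (the random scores below are toy stand-ins for real cosine distances and labels):

import numpy as np
from sklearn.metrics import roc_curve

rng = np.random.default_rng(0)
pred = rng.random(1000)          # stand-in verification scores
same = rng.integers(0, 2, 1000)  # stand-in ground-truth "same identity" flags

fpr, tpr, thresholds = roc_curve(same, pred)
best = np.argmin(np.abs(tpr - (1 - fpr)))  # operating point where TPR ~= 1 - FPR
print("acc=%.4f at threshold=%.4f" % (tpr[best], thresholds[best]))

Training itself is launched through start_train(); the run name (-n) is the only required flag.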
third_party/arcface/resnet.py
ADDED
@@ -0,0 +1,2 @@
import torch
from torchvision.models import resnet50
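This two-line module simply re-exports torchvision's resnet50 so other code can import it under third_party.arcface.resnet. A minimal usage sketch (num_classes=512 is an illustrative choice, not a value taken from this repo):

from third_party.arcface.resnet import resnet50

model = resnet50(num_classes=512)  # plain torchvision ResNet-50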
third_party/arcface/utils_callbacks.py
ADDED
@@ -0,0 +1,141 @@
import logging
import os
import time
from typing import List

import torch

from third_party.arcface import verification


class AverageMeter(object):
    """ Computes and stores the average and current value
    """
    def __init__(self):
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CallBackVerification(object):
    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112),
                 is_gray=False):
        self.frequent: int = frequent
        self.rank: int = rank
        self.highest_acc: float = 0.0
        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        if self.rank == 0:  # only rank 0 evaluates (`is 0` replaced with `== 0`)
            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
        self.is_gray = is_gray

    def ver_test(self, backbone: torch.nn.Module, global_step: int):
        results = []
        for i in range(len(self.ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                self.ver_list[i], backbone, 10, 10,
                is_gray=self.is_gray)
            # logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
            # logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
            print('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            # logging.info(
            #     '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
            print(
                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
            results.append(acc2)

    def init_dataset(self, val_targets, data_dir, image_size):
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)

    def __call__(self, num_update, backbone: torch.nn.Module):
        if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
            backbone.eval()
            self.ver_test(backbone, num_update)
            backbone.train()


class CallBackLogging(object):
    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
        self.frequent: int = frequent
        self.rank: int = rank
        self.time_start = time.time()
        self.total_step: int = total_step
        self.batch_size: int = batch_size
        self.world_size: int = world_size
        self.writer = writer

        self.init = False
        self.tic = 0

    def __call__(self, global_step, loss: AverageMeter, epoch: int, fp16: bool, grad_scaler: torch.cuda.amp.GradScaler):
        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
            if self.init:
                try:
                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
                    speed_total = speed * self.world_size
                except ZeroDivisionError:
                    speed_total = float('inf')

                time_now = (time.time() - self.time_start) / 3600
                time_total = time_now / ((global_step + 1) / self.total_step)
                time_for_end = time_total - time_now
                if self.writer is not None:
                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
                    self.writer.add_scalar('loss', loss.avg, global_step)
                if fp16:
                    msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d " \
                          "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
                              speed_total, loss.avg, epoch, global_step, grad_scaler.get_scale(), time_for_end
                          )
                else:
                    msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % (
                        speed_total, loss.avg, epoch, global_step, time_for_end
                    )
                logging.info(msg)
                loss.reset()
                self.tic = time.time()
            else:
                self.init = True
                self.tic = time.time()


class CallBackModelCheckpoint(object):
    def __init__(self, rank, output="./"):
        self.rank: int = rank
        self.output: str = output

    def __call__(self,
                 global_step,
                 backbone: torch.nn.Module,
                 partial_fc=None,
                 awloss=None):
        print('CallBackModelCheckpoint...')
        if global_step > 100 and self.rank == 0:
            torch.save(backbone.module.state_dict(), os.path.join(self.output, "backbone.pth"))
        if global_step > 100 and partial_fc is not None:
            partial_fc.save_params()
        if global_step > 100 and awloss is not None:
            torch.save(awloss.state_dict(), os.path.join(self.output, "awloss.pth"))
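These callbacks target a hand-rolled (non-Lightning) training loop: CallBackLogging is invoked every step with a running AverageMeter, and CallBackVerification periodically switches the backbone to eval mode and runs the .bin verification sets. A minimal single-GPU sketch, assuming backbone, loader, loss_fn, and optimizer are defined elsewhere and that "/path/to/bins" is a placeholder:

from third_party.arcface.utils_callbacks import (
    AverageMeter, CallBackLogging, CallBackVerification)

loss_meter = AverageMeter()
log_cb = CallBackLogging(frequent=50, rank=0, total_step=100000,
                         batch_size=128, world_size=1)
ver_cb = CallBackVerification(frequent=2000, rank=0,
                              val_targets=["lfw"], rec_prefix="/path/to/bins")

global_step = 0
for img, label in loader:                       # assumed dataloader
    loss = loss_fn(backbone(img.cuda()), label.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_meter.update(loss.item(), n=img.size(0))
    log_cb(global_step, loss_meter, epoch=0, fp16=False, grad_scaler=None)
    ver_cb(global_step, backbone)               # runs verification every 2000 steps
    global_step += 1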
third_party/arcface/verification.py
ADDED
@@ -0,0 +1,440 @@
"""Helper for evaluation on the Labeled Faces in the Wild dataset
"""

# MIT License
#
# Copyright (c) 2016 David Sandberg
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import datetime
import os
import pickle

import mxnet as mx
import numpy as np
import sklearn
import torch
from mxnet import ndarray as nd
from scipy import interpolate
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold


class LFold:
    def __init__(self, n_splits=2, shuffle=False):
        self.n_splits = n_splits
        if self.n_splits > 1:
            self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)

    def split(self, indices):
        if self.n_splits > 1:
            return self.k_fold.split(indices)
        else:
            return [(indices, indices)]


def calculate_roc(thresholds,
                  embeddings1,
                  embeddings2,
                  actual_issame,
                  nrof_folds=10,
                  pca=0):
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = LFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds, nrof_thresholds))
    fprs = np.zeros((nrof_folds, nrof_thresholds))
    accuracy = np.zeros((nrof_folds))
    indices = np.arange(nrof_pairs)

    if pca == 0:
        diff = np.subtract(embeddings1, embeddings2)
        dist = np.sum(np.square(diff), 1)
        print('dist', dist.min(), dist.max())

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
        if pca > 0:
            print('doing pca on', fold_idx)
            embed1_train = embeddings1[train_set]
            embed2_train = embeddings2[train_set]
            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
            pca_model = PCA(n_components=pca)
            pca_model.fit(_embed_train)
            embed1 = pca_model.transform(embeddings1)
            embed2 = pca_model.transform(embeddings2)
            embed1 = sklearn.preprocessing.normalize(embed1)
            embed2 = sklearn.preprocessing.normalize(embed2)
            diff = np.subtract(embed1, embed2)
            dist = np.sum(np.square(diff), 1)

        # Find the best threshold for the fold
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(
                threshold, dist[train_set], actual_issame[train_set])
        best_threshold_index = np.argmax(acc_train)
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
                threshold, dist[test_set],
                actual_issame[test_set])
        _, _, accuracy[fold_idx] = calculate_accuracy(
            thresholds[best_threshold_index], dist[test_set],
            actual_issame[test_set])

    tpr = np.mean(tprs, 0)
    fpr = np.mean(fprs, 0)
    return tpr, fpr, accuracy


def calculate_accuracy(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(
        np.logical_and(np.logical_not(predict_issame),
                       np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
    acc = float(tp + tn) / dist.size
    return tpr, fpr, acc


def calculate_val(thresholds,
                  embeddings1,
                  embeddings2,
                  actual_issame,
                  far_target,
                  nrof_folds=10):
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = LFold(n_splits=nrof_folds, shuffle=False)

    val = np.zeros(nrof_folds)
    far = np.zeros(nrof_folds)

    diff = np.subtract(embeddings1, embeddings2)
    dist = np.sum(np.square(diff), 1)
    indices = np.arange(nrof_pairs)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):

        # Find the threshold that gives FAR = far_target
        far_train = np.zeros(nrof_thresholds)
        for threshold_idx, threshold in enumerate(thresholds):
            _, far_train[threshold_idx] = calculate_val_far(
                threshold, dist[train_set], actual_issame[train_set])
        if np.max(far_train) >= far_target:
            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
            threshold = f(far_target)
        else:
            threshold = 0.0

        val[fold_idx], far[fold_idx] = calculate_val_far(
            threshold, dist[test_set], actual_issame[test_set])

    val_mean = np.mean(val)
    far_mean = np.mean(far)
    val_std = np.std(val)
    return val_mean, val_std, far_mean


def calculate_val_far(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
    false_accept = np.sum(
        np.logical_and(predict_issame, np.logical_not(actual_issame)))
    n_same = np.sum(actual_issame)
    n_diff = np.sum(np.logical_not(actual_issame))
    # print(true_accept, false_accept)
    # print(actual_issame)
    # print(n_same, n_diff)
    val = float(true_accept) / float(n_same)
    far = float(false_accept) / float(n_diff)
    return val, far


def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
    # Calculate evaluation metrics
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    tpr, fpr, accuracy = calculate_roc(thresholds,
                                       embeddings1,
                                       embeddings2,
                                       np.asarray(actual_issame),
                                       nrof_folds=nrof_folds,
                                       pca=pca)
    thresholds = np.arange(0, 4, 0.001)
    val, val_std, far = calculate_val(thresholds,
                                      embeddings1,
                                      embeddings2,
                                      np.asarray(actual_issame),
                                      1e-3,
                                      nrof_folds=nrof_folds)
    return tpr, fpr, accuracy, val, val_std, far


@torch.no_grad()
def load_bin(path, image_size):
    try:
        with open(path, 'rb') as f:
            bins, issame_list = pickle.load(f)  # py2
    except UnicodeDecodeError:
        with open(path, 'rb') as f:
            bins, issame_list = pickle.load(f, encoding='bytes')  # py3
    data_list = []
    for flip in [0, 1]:
        data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
        data_list.append(data)
    for idx in range(len(issame_list) * 2):
        _bin = bins[idx]
        img = mx.image.imdecode(_bin)
        if img.shape[1] != image_size[0]:
            img = mx.image.resize_short(img, image_size[0])
        img = nd.transpose(img, axes=(2, 0, 1))  # (C, H, W)
        for flip in [0, 1]:
            if flip == 1:
                img = mx.ndarray.flip(data=img, axis=2)
            data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
        if idx % 1000 == 0:
            print('loading bin', idx)

        # # save img to '/home/yuange/dataset/LFW/rgb-arcface'
        # img = nd.transpose(img, axes=(1, 2, 0))  # (H, W, C)
        # save_name = 'ind_' + str(idx) + '.bmp'
        # import os
        # save_name = os.path.join('/home/yuange/dataset/LFW/rgb-arcface', save_name)
        # import PIL.Image as Image
        # fig = Image.fromarray(img.asnumpy(), mode='RGB')
        # fig.save(save_name)

    print('load finished', data_list[0].shape)
    return data_list, issame_list


@torch.no_grad()
def test(data_set, backbone, batch_size, nfolds=10,
         is_gray=False):
    print('testing verification..')
    data_list = data_set[0]
    issame_list = data_set[1]
    embeddings_list = []
    time_consumed = 0.0
    for i in range(len(data_list)):
        data = data_list[i]  # (12000, 3, 112, 112)

        print(data.shape)
        if is_gray:
            data = (0.2989 * data[:, 0] +
                    0.5870 * data[:, 1] +
                    0.1140 * data[:, 2]) / 3
            data = data[:, None, :, :]
            print(data.shape)

        embeddings = None
        ba = 0
        while ba < data.shape[0]:
            bb = min(ba + batch_size, data.shape[0])
            count = bb - ba
            _data = data[bb - batch_size: bb]
            time0 = datetime.datetime.now()

            if not is_gray:
                img = ((_data / 255) - 0.5) / 0.5
            else:
                img = _data / 255

            # mouth_net returns a feature whether in training or testing
            feature = backbone(img.cuda(0))
            net_out: torch.Tensor = feature

            _embeddings = net_out.detach().cpu().numpy()
            time_now = datetime.datetime.now()
            diff = time_now - time0
            time_consumed += diff.total_seconds()
            if embeddings is None:
                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
            ba = bb
        embeddings_list.append(embeddings)

    print('emb_list', len(embeddings_list), embeddings_list[0].size, embeddings_list[1].size)
    _xnorm = 0.0
    _xnorm_cnt = 0
    for embed in embeddings_list:
        for i in range(embed.shape[0]):
            _em = embed[i]
            _norm = np.linalg.norm(_em)
            _xnorm += _norm
            _xnorm_cnt += 1
    _xnorm /= _xnorm_cnt

    embeddings = embeddings_list[0].copy()
    embeddings = sklearn.preprocessing.normalize(embeddings)
    acc1 = 0.0
    std1 = 0.0
    embeddings = embeddings_list[0] + embeddings_list[1]
    embeddings = sklearn.preprocessing.normalize(embeddings)
    print('embeddings.shape', embeddings.shape)
    print('infer time', time_consumed)
    _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
    acc2, std2 = np.mean(accuracy), np.std(accuracy)
    return acc1, std1, acc2, std2, _xnorm, embeddings_list


def dumpR(data_set,
          backbone,
          batch_size,
          name='',
          data_extra=None,
          label_shape=None):
    print('dump verification embedding..')
    # legacy mxnet code path: `backbone` is expected to be a bound mx.mod.Module here
    model = backbone
    data_list = data_set[0]
    issame_list = data_set[1]
    embeddings_list = []
    time_consumed = 0.0
    for i in range(len(data_list)):
        data = data_list[i]
        embeddings = None
        ba = 0
        while ba < data.shape[0]:
            bb = min(ba + batch_size, data.shape[0])
            count = bb - ba

            _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
            _label = nd.ones((batch_size,))  # dummy labels; `_label` was undefined in the original
            time0 = datetime.datetime.now()
            if data_extra is None:
                db = mx.io.DataBatch(data=(_data,), label=(_label,))
            else:
                # `_data_extra` was undefined in the original; pass `data_extra` through directly
                db = mx.io.DataBatch(data=(_data, nd.array(data_extra)),
                                     label=(_label,))
            model.forward(db, is_train=False)
            net_out = model.get_outputs()
            _embeddings = net_out[0].asnumpy()
            time_now = datetime.datetime.now()
            diff = time_now - time0
            time_consumed += diff.total_seconds()
            if embeddings is None:
                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
            ba = bb
        embeddings_list.append(embeddings)
    embeddings = embeddings_list[0] + embeddings_list[1]
    embeddings = sklearn.preprocessing.normalize(embeddings)
    actual_issame = np.asarray(issame_list)
    outname = os.path.join('temp.bin')
    with open(outname, 'wb') as f:
        pickle.dump((embeddings, issame_list),
                    f,
                    protocol=pickle.HIGHEST_PROTOCOL)


# if __name__ == '__main__':
#
#     parser = argparse.ArgumentParser(description='do verification')
#     # general
#     parser.add_argument('--data-dir', default='', help='')
#     parser.add_argument('--model',
#                         default='../model/softmax,50',
#                         help='path to load model.')
#     parser.add_argument('--target',
#                         default='lfw,cfp_ff,cfp_fp,agedb_30',
#                         help='test targets.')
#     parser.add_argument('--gpu', default=0, type=int, help='gpu id')
#     parser.add_argument('--batch-size', default=32, type=int, help='')
#     parser.add_argument('--max', default='', type=str, help='')
#     parser.add_argument('--mode', default=0, type=int, help='')
#     parser.add_argument('--nfolds', default=10, type=int, help='')
#     args = parser.parse_args()
#     image_size = [112, 112]
#     print('image_size', image_size)
#     ctx = mx.gpu(args.gpu)
#     nets = []
#     vec = args.model.split(',')
#     prefix = args.model.split(',')[0]
#     epochs = []
#     if len(vec) == 1:
#         pdir = os.path.dirname(prefix)
#         for fname in os.listdir(pdir):
#             if not fname.endswith('.params'):
#                 continue
#             _file = os.path.join(pdir, fname)
#             if _file.startswith(prefix):
#                 epoch = int(fname.split('.')[0].split('-')[1])
#                 epochs.append(epoch)
#         epochs = sorted(epochs, reverse=True)
#         if len(args.max) > 0:
#             _max = [int(x) for x in args.max.split(',')]
#             assert len(_max) == 2
#             if len(epochs) > _max[1]:
#                 epochs = epochs[_max[0]:_max[1]]
#
#     else:
#         epochs = [int(x) for x in vec[1].split('|')]
#     print('model number', len(epochs))
#     time0 = datetime.datetime.now()
#     for epoch in epochs:
#         print('loading', prefix, epoch)
#         sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
#         # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
#         all_layers = sym.get_internals()
#         sym = all_layers['fc1_output']
#         model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
#         # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
#         model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
#                                           image_size[1]))])
#         model.set_params(arg_params, aux_params)
#         nets.append(model)
#     time_now = datetime.datetime.now()
#     diff = time_now - time0
#     print('model loading time', diff.total_seconds())
#
#     ver_list = []
#     ver_name_list = []
#     for name in args.target.split(','):
#         path = os.path.join(args.data_dir, name + ".bin")
#         if os.path.exists(path):
#             print('loading.. ', name)
#             data_set = load_bin(path, image_size)
#             ver_list.append(data_set)
#             ver_name_list.append(name)
#
#     if args.mode == 0:
#         for i in range(len(ver_list)):
#             results = []
#             for model in nets:
#                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(
#                     ver_list[i], model, args.batch_size, args.nfolds)
#                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
#                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
#                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
#                 results.append(acc2)
#             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
#     elif args.mode == 1:
#         raise ValueError
#     else:
#         model = nets[0]
#         dumpR(ver_list[0], model, args.batch_size, args.target)
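Putting the pieces together: load_bin decodes an mxnet-style .bin pack into an original and a horizontally flipped tensor set, and test sums the two embeddings per image before the 10-fold evaluation in evaluate. A minimal sketch of scoring a backbone on LFW; the .bin path is a placeholder, the toy backbone only illustrates the expected interface ((N, 3, 112, 112) images in, embeddings out), and a GPU is assumed since test moves batches to cuda:0:

import torch
from third_party.arcface import verification

backbone = torch.nn.Sequential(      # toy stand-in for a real face model
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 112 * 112, 128),
).cuda(0).eval()

data_set = verification.load_bin("/path/to/lfw.bin", image_size=(112, 112))
acc1, std1, acc2, std2, xnorm, _ = verification.test(data_set, backbone, batch_size=64, nfolds=10)
print("Accuracy-Flip: %1.5f+-%1.5f, XNorm: %f" % (acc2, std2, xnorm))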