wmpscc committed
Commit 7d1312d
1 Parent(s): 6aa2a6e
Files changed (39)
  1. app.py +56 -0
  2. configs.py +9 -0
  3. data/fairface_gender_angle.csv +0 -0
  4. img/demo.png +0 -0
  5. img/pic_top.jpg +0 -0
  6. models/__init__.py +0 -0
  7. models/__pycache__/__init__.cpython-39.pyc +0 -0
  8. models/encoders/__init__.py +0 -0
  9. models/encoders/helpers.py +140 -0
  10. models/encoders/model_irse.py +84 -0
  11. models/encoders/psp_encoders.py +236 -0
  12. models/stylegan2/__init__.py +0 -0
  13. models/stylegan2/__pycache__/__init__.cpython-39.pyc +0 -0
  14. models/stylegan2/__pycache__/model.cpython-39.pyc +0 -0
  15. models/stylegan2/model.py +674 -0
  16. models/stylegan2/op/__init__.py +2 -0
  17. models/stylegan2/op/__pycache__/__init__.cpython-39.pyc +0 -0
  18. models/stylegan2/op/__pycache__/fused_act.cpython-39.pyc +0 -0
  19. models/stylegan2/op/fused_act.py +86 -0
  20. models/stylegan2/op/fused_bias_act.cpp +21 -0
  21. models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
  22. models/stylegan2/op/upfirdn2d.cpp +23 -0
  23. models/stylegan2/op/upfirdn2d.py +186 -0
  24. models/stylegan2/op/upfirdn2d_kernel.cu +272 -0
  25. models/stylegene/__init__.py +0 -0
  26. models/stylegene/__pycache__/__init__.cpython-39.pyc +0 -0
  27. models/stylegene/__pycache__/api.cpython-39.pyc +0 -0
  28. models/stylegene/api.py +94 -0
  29. models/stylegene/data_util.py +36 -0
  30. models/stylegene/fair_face_model.py +61 -0
  31. models/stylegene/gene_crossover_mutation.py +64 -0
  32. models/stylegene/gene_pool.py +42 -0
  33. models/stylegene/model.py +209 -0
  34. models/stylegene/util.py +30 -0
  35. preprocess/__init__.py +0 -0
  36. preprocess/align_images.py +32 -0
  37. preprocess/face_alignment.py +87 -0
  38. preprocess/landmarks_detector.py +24 -0
  39. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,56 @@
+ import gradio as gr
+
+ from models.stylegene.api import synthesize_descendant
+
+ description = """<p style="text-align: center; font-weight: bold;">
+ <span style="font-size: 28px">StyleGene: Crossover and Mutation of Region-Level Facial Genes for Kinship Face Synthesis</span>
+ <br>
+ <span style="font-size: 18px" id="paper-info">
+ [<a href="https://wmpscc.github.io/stylegene/" target="_blank">Project Page</a>]
+ [<a href="https://openaccess.thecvf.com/content/CVPR2023/papers/Li_StyleGene_Crossover_and_Mutation_of_Region-Level_Facial_Genes_for_Kinship_CVPR_2023_paper.pdf" target="_blank">Paper</a>]
+ [<a href="https://github.com/CVI-SZU/StyleGene" target="_blank">GitHub</a>]
+ </span>
+ <br>
+ <a> Tip: Each picture should contain only one face.</a>
+ </p>"""
+
+ block = gr.Blocks()
+ with block:
+     gr.HTML(description)
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### Upload photos of father and mother")
+             with gr.Row():
+                 img1 = gr.Image(label="Father")
+                 img2 = gr.Image(label="Mother")
+             gr.Markdown("### Select the child's age and gender")
+             with gr.Row():
+                 age = gr.Dropdown(label="Age",
+                                   choices=["0-2", "3-9", "10-19", "20-29", "30-39",
+                                            "40-49", "50-59", "60-69", "70+"], value="3-9")
+                 gender = gr.Dropdown(label="Gender", choices=["male", "female"], value="female")
+             gr.Markdown("### Adjust your child's resemblance to parents")
+             bar1 = gr.Slider(label="gamma", minimum=0, maximum=1, value=0.47)
+             bar2 = gr.Slider(label="eta", minimum=0, maximum=1, value=0.4)
+             bt_run = gr.Button("Run")
+             gr.Markdown("""## Disclaimer
+             This method is intended for academic research purposes only and is strictly prohibited for commercial use.
+             Users are required to comply with all local laws and regulations when using this method.""")
+
+         with gr.Column():
+             gr.Markdown("### Results")
+             img3 = gr.Image(label="Generated child")
+             with gr.Row():
+                 img1_align = gr.Image(label="Father")
+                 img2_align = gr.Image(label="Mother")
+
+
+     def run(father, mother, gamma, eta, age, gender):
+         attributes = {'age': age, 'gender': gender, 'gamma': float(gamma), 'eta': float(eta)}
+         img_F, img_M, img_C = synthesize_descendant(father, mother, attributes)
+         return img_F, img_M, img_C
+
+
+     bt_run.click(run, [img1, img2, bar1, bar2, age, gender], [img1_align, img2_align, img3])
+
+ block.launch(show_error=True)
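Note: the run callback above calls models.stylegene.api.synthesize_descendant, which in turn needs the checkpoints listed in configs.py. A hypothetical stand-in like the one below (name and behaviour are assumptions, not part of this commit) can be wired to bt_run.click to smoke-test the Gradio interface without downloading any weights:

# Hypothetical stand-in for models.stylegene.api.synthesize_descendant,
# only for exercising the UI: it echoes the parents and returns a blank "child".
import numpy as np

def synthesize_descendant_stub(father, mother, attributes):
    # father / mother arrive as RGB numpy arrays from gr.Image
    child = np.zeros_like(father) if father is not None else np.zeros((256, 256, 3), dtype=np.uint8)
    return father, mother, child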
configs.py ADDED
@@ -0,0 +1,9 @@
+ path_ckpt_landmark68 = "checkpoints/shape_predictor_68_face_landmarks.dat.bz2"
+ path_ckpt_e4e = "/home/cvi_demo/PythonProject/StyleGene/checkpoints/e4e_ffhq_encode.pt"
+ path_ckpt_stylegan2 = '/home/cvi_demo/PythonProject/StyleGene/checkpoints/stylegan2-ffhq-config-f.pt'
+ path_ckpt_stylegene = "/home/cvi_demo/PythonProject/StyleGene/checkpoints/stylegene_N18.ckpt"
+ path_ckpt_fairface = '/home/cvi_demo/PythonProject/StyleGene/checkpoints/res34_fair_align_multi_7_20190809.pt'
+ path_ckpt_genepool = "/home/cvi_demo/PythonProject/StyleGene/checkpoints/geneFactorPool.pkl"
+
+ path_csv_ffhq_attritube = 'data/fairface_gender_angle.csv'
+ path_dataset_ffhq = None
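Most of the checkpoint paths above are absolute and machine-specific. A sketch (not part of this commit; the environment-variable name is illustrative) of resolving them from a configurable root instead:

# Sketch: derive checkpoint paths from an overridable root directory so the
# hard-coded /home/cvi_demo/... prefix is not required on other machines.
import os

CKPT_ROOT = os.environ.get("STYLEGENE_CKPT_ROOT", "checkpoints")
path_ckpt_e4e = os.path.join(CKPT_ROOT, "e4e_ffhq_encode.pt")
path_ckpt_stylegan2 = os.path.join(CKPT_ROOT, "stylegan2-ffhq-config-f.pt")
path_ckpt_stylegene = os.path.join(CKPT_ROOT, "stylegene_N18.ckpt")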
data/fairface_gender_angle.csv ADDED
The diff for this file is too large to render.
 
img/demo.png ADDED
img/pic_top.jpg ADDED
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (135 Bytes).
 
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,140 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
5
+
6
+ """
7
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8
+ """
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, input):
13
+ return input.view(input.size(0), -1)
14
+
15
+
16
+ def l2_norm(input, axis=1):
17
+ norm = torch.norm(input, 2, axis, True)
18
+ output = torch.div(input, norm)
19
+ return output
20
+
21
+
22
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
23
+ """ A named tuple describing a ResNet block. """
24
+
25
+
26
+ def get_block(in_channel, depth, num_units, stride=2):
27
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
28
+
29
+
30
+ def get_blocks(num_layers):
31
+ if num_layers == 50:
32
+ blocks = [
33
+ get_block(in_channel=64, depth=64, num_units=3),
34
+ get_block(in_channel=64, depth=128, num_units=4),
35
+ get_block(in_channel=128, depth=256, num_units=14),
36
+ get_block(in_channel=256, depth=512, num_units=3)
37
+ ]
38
+ elif num_layers == 100:
39
+ blocks = [
40
+ get_block(in_channel=64, depth=64, num_units=3),
41
+ get_block(in_channel=64, depth=128, num_units=13),
42
+ get_block(in_channel=128, depth=256, num_units=30),
43
+ get_block(in_channel=256, depth=512, num_units=3)
44
+ ]
45
+ elif num_layers == 152:
46
+ blocks = [
47
+ get_block(in_channel=64, depth=64, num_units=3),
48
+ get_block(in_channel=64, depth=128, num_units=8),
49
+ get_block(in_channel=128, depth=256, num_units=36),
50
+ get_block(in_channel=256, depth=512, num_units=3)
51
+ ]
52
+ else:
53
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
54
+ return blocks
55
+
56
+
57
+ class SEModule(Module):
58
+ def __init__(self, channels, reduction):
59
+ super(SEModule, self).__init__()
60
+ self.avg_pool = AdaptiveAvgPool2d(1)
61
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
62
+ self.relu = ReLU(inplace=True)
63
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
64
+ self.sigmoid = Sigmoid()
65
+
66
+ def forward(self, x):
67
+ module_input = x
68
+ x = self.avg_pool(x)
69
+ x = self.fc1(x)
70
+ x = self.relu(x)
71
+ x = self.fc2(x)
72
+ x = self.sigmoid(x)
73
+ return module_input * x
74
+
75
+
76
+ class bottleneck_IR(Module):
77
+ def __init__(self, in_channel, depth, stride):
78
+ super(bottleneck_IR, self).__init__()
79
+ if in_channel == depth:
80
+ self.shortcut_layer = MaxPool2d(1, stride)
81
+ else:
82
+ self.shortcut_layer = Sequential(
83
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
84
+ BatchNorm2d(depth)
85
+ )
86
+ self.res_layer = Sequential(
87
+ BatchNorm2d(in_channel),
88
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
89
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
90
+ )
91
+
92
+ def forward(self, x):
93
+ shortcut = self.shortcut_layer(x)
94
+ res = self.res_layer(x)
95
+ return res + shortcut
96
+
97
+
98
+ class bottleneck_IR_SE(Module):
99
+ def __init__(self, in_channel, depth, stride):
100
+ super(bottleneck_IR_SE, self).__init__()
101
+ if in_channel == depth:
102
+ self.shortcut_layer = MaxPool2d(1, stride)
103
+ else:
104
+ self.shortcut_layer = Sequential(
105
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
106
+ BatchNorm2d(depth)
107
+ )
108
+ self.res_layer = Sequential(
109
+ BatchNorm2d(in_channel),
110
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
111
+ PReLU(depth),
112
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
113
+ BatchNorm2d(depth),
114
+ SEModule(depth, 16)
115
+ )
116
+
117
+ def forward(self, x):
118
+ shortcut = self.shortcut_layer(x)
119
+ res = self.res_layer(x)
120
+ return res + shortcut
121
+
122
+
123
+ def _upsample_add(x, y):
124
+ """Upsample and add two feature maps.
125
+ Args:
126
+ x: (Variable) top feature map to be upsampled.
127
+ y: (Variable) lateral feature map.
128
+ Returns:
129
+ (Variable) added feature map.
130
+ Note in PyTorch, when input size is odd, the upsampled feature map
131
+ with `F.upsample(..., scale_factor=2, mode='nearest')`
132
+ maybe not equal to the lateral feature map size.
133
+ e.g.
134
+ original input size: [N,_,15,15] ->
135
+ conv2d feature map size: [N,_,8,8] ->
136
+ upsampled feature map size: [N,_,16,16]
137
+ So we choose bilinear upsample which supports arbitrary output sizes.
138
+ """
139
+ _, _, H, W = y.size()
140
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
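As the docstring notes, _upsample_add interpolates the top map to the lateral map's exact size (bilinear, align_corners=True), which a fixed scale_factor=2 upsample cannot guarantee for odd sizes. A small sanity-check sketch (shapes are illustrative):

# Sketch: the summed output always takes the lateral map's spatial size,
# even when that size is odd.
import torch
from models.encoders.helpers import _upsample_add

top = torch.randn(1, 512, 16, 16)      # coarser FPN level
lateral = torch.randn(1, 512, 33, 33)  # odd-sized lateral feature map
assert _upsample_add(top, lateral).shape == lateral.shape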
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class Backbone(Module):
10
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11
+ super(Backbone, self).__init__()
12
+ assert input_size in [112, 224], "input_size should be 112 or 224"
13
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15
+ blocks = get_blocks(num_layers)
16
+ if mode == 'ir':
17
+ unit_module = bottleneck_IR
18
+ elif mode == 'ir_se':
19
+ unit_module = bottleneck_IR_SE
20
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21
+ BatchNorm2d(64),
22
+ PReLU(64))
23
+ if input_size == 112:
24
+ self.output_layer = Sequential(BatchNorm2d(512),
25
+ Dropout(drop_ratio),
26
+ Flatten(),
27
+ Linear(512 * 7 * 7, 512),
28
+ BatchNorm1d(512, affine=affine))
29
+ else:
30
+ self.output_layer = Sequential(BatchNorm2d(512),
31
+ Dropout(drop_ratio),
32
+ Flatten(),
33
+ Linear(512 * 14 * 14, 512),
34
+ BatchNorm1d(512, affine=affine))
35
+
36
+ modules = []
37
+ for block in blocks:
38
+ for bottleneck in block:
39
+ modules.append(unit_module(bottleneck.in_channel,
40
+ bottleneck.depth,
41
+ bottleneck.stride))
42
+ self.body = Sequential(*modules)
43
+
44
+ def forward(self, x):
45
+ x = self.input_layer(x)
46
+ x = self.body(x)
47
+ x = self.output_layer(x)
48
+ return l2_norm(x)
49
+
50
+
51
+ def IR_50(input_size):
52
+ """Constructs a ir-50 model."""
53
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54
+ return model
55
+
56
+
57
+ def IR_101(input_size):
58
+ """Constructs a ir-101 model."""
59
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60
+ return model
61
+
62
+
63
+ def IR_152(input_size):
64
+ """Constructs a ir-152 model."""
65
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66
+ return model
67
+
68
+
69
+ def IR_SE_50(input_size):
70
+ """Constructs a ir_se-50 model."""
71
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72
+ return model
73
+
74
+
75
+ def IR_SE_101(input_size):
76
+ """Constructs a ir_se-101 model."""
77
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78
+ return model
79
+
80
+
81
+ def IR_SE_152(input_size):
82
+ """Constructs a ir_se-152 model."""
83
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84
+ return model
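The IR/IR-SE backbones defined here map a 112x112 (or 224x224) face crop to an L2-normalised 512-d embedding. A minimal usage sketch (input is random, for shape-checking only):

# Sketch: IR-SE-50 produces an embedding of shape [N, 512] for a 112x112 crop.
import torch
from models.encoders.model_irse import IR_SE_50

model = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512])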
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,236 @@
1
+ from enum import Enum
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
7
+
8
+ from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
9
+ from models.stylegan2.model import EqualLinear
10
+
11
+
12
+ # Adapted from https://github.com/omertov/encoder4editing
13
+ class ProgressiveStage(Enum):
14
+ WTraining = 0
15
+ Delta1Training = 1
16
+ Delta2Training = 2
17
+ Delta3Training = 3
18
+ Delta4Training = 4
19
+ Delta5Training = 5
20
+ Delta6Training = 6
21
+ Delta7Training = 7
22
+ Delta8Training = 8
23
+ Delta9Training = 9
24
+ Delta10Training = 10
25
+ Delta11Training = 11
26
+ Delta12Training = 12
27
+ Delta13Training = 13
28
+ Delta14Training = 14
29
+ Delta15Training = 15
30
+ Delta16Training = 16
31
+ Delta17Training = 17
32
+ Inference = 18
33
+
34
+
35
+ class GradualStyleBlock(Module):
36
+ def __init__(self, in_c, out_c, spatial):
37
+ super(GradualStyleBlock, self).__init__()
38
+ self.out_c = out_c
39
+ self.spatial = spatial
40
+ num_pools = int(np.log2(spatial))
41
+ modules = []
42
+ modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
43
+ nn.LeakyReLU()]
44
+ for i in range(num_pools - 1):
45
+ modules += [
46
+ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
47
+ nn.LeakyReLU()
48
+ ]
49
+ self.convs = nn.Sequential(*modules)
50
+ self.linear = EqualLinear(out_c, out_c, lr_mul=1)
51
+
52
+ def forward(self, x):
53
+ x = self.convs(x)
54
+ x = x.view(-1, self.out_c)
55
+ x = self.linear(x)
56
+ return x
57
+
58
+
59
+ class GradualStyleEncoder(Module):
60
+ def __init__(self, num_layers, mode='ir', opts=None):
61
+ super(GradualStyleEncoder, self).__init__()
62
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
63
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
64
+ blocks = get_blocks(num_layers)
65
+ if mode == 'ir':
66
+ unit_module = bottleneck_IR
67
+ elif mode == 'ir_se':
68
+ unit_module = bottleneck_IR_SE
69
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
70
+ BatchNorm2d(64),
71
+ PReLU(64))
72
+ modules = []
73
+ for block in blocks:
74
+ for bottleneck in block:
75
+ modules.append(unit_module(bottleneck.in_channel,
76
+ bottleneck.depth,
77
+ bottleneck.stride))
78
+ self.body = Sequential(*modules)
79
+
80
+ self.styles = nn.ModuleList()
81
+ log_size = int(math.log(opts.stylegan_size, 2))
82
+ self.style_count = 2 * log_size - 2
83
+ self.coarse_ind = 3
84
+ self.middle_ind = 7
85
+ for i in range(self.style_count):
86
+ if i < self.coarse_ind:
87
+ style = GradualStyleBlock(512, 512, 16)
88
+ elif i < self.middle_ind:
89
+ style = GradualStyleBlock(512, 512, 32)
90
+ else:
91
+ style = GradualStyleBlock(512, 512, 64)
92
+ self.styles.append(style)
93
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
94
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
95
+
96
+ def forward(self, x):
97
+ x = self.input_layer(x)
98
+
99
+ latents = []
100
+ modulelist = list(self.body._modules.values())
101
+ for i, l in enumerate(modulelist):
102
+ x = l(x)
103
+ if i == 6:
104
+ c1 = x
105
+ elif i == 20:
106
+ c2 = x
107
+ elif i == 23:
108
+ c3 = x
109
+
110
+ for j in range(self.coarse_ind):
111
+ latents.append(self.styles[j](c3))
112
+
113
+ p2 = _upsample_add(c3, self.latlayer1(c2))
114
+ for j in range(self.coarse_ind, self.middle_ind):
115
+ latents.append(self.styles[j](p2))
116
+
117
+ p1 = _upsample_add(p2, self.latlayer2(c1))
118
+ for j in range(self.middle_ind, self.style_count):
119
+ latents.append(self.styles[j](p1))
120
+
121
+ out = torch.stack(latents, dim=1)
122
+ return out
123
+
124
+
125
+ class Encoder4Editing(Module):
126
+ def __init__(self, num_layers, mode='ir', stylegan_size=1024):
127
+ super(Encoder4Editing, self).__init__()
128
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
129
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
130
+ blocks = get_blocks(num_layers)
131
+ if mode == 'ir':
132
+ unit_module = bottleneck_IR
133
+ elif mode == 'ir_se':
134
+ unit_module = bottleneck_IR_SE
135
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
136
+ BatchNorm2d(64),
137
+ PReLU(64))
138
+ modules = []
139
+ for block in blocks:
140
+ for bottleneck in block:
141
+ modules.append(unit_module(bottleneck.in_channel,
142
+ bottleneck.depth,
143
+ bottleneck.stride))
144
+ self.body = Sequential(*modules)
145
+
146
+ self.styles = nn.ModuleList()
147
+ log_size = int(math.log(stylegan_size, 2))
148
+ self.style_count = 2 * log_size - 2
149
+ self.coarse_ind = 3
150
+ self.middle_ind = 7
151
+
152
+ for i in range(self.style_count):
153
+ if i < self.coarse_ind:
154
+ style = GradualStyleBlock(512, 512, 16)
155
+ elif i < self.middle_ind:
156
+ style = GradualStyleBlock(512, 512, 32)
157
+ else:
158
+ style = GradualStyleBlock(512, 512, 64)
159
+ self.styles.append(style)
160
+
161
+ self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
162
+ self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
163
+
164
+ self.progressive_stage = ProgressiveStage.Inference
165
+
166
+ def get_deltas_starting_dimensions(self):
167
+ ''' Get a list of the initial dimension of every delta from which it is applied '''
168
+ return list(range(self.style_count)) # Each dimension has a delta applied to it
169
+
170
+ def set_progressive_stage(self, new_stage: ProgressiveStage):
171
+ self.progressive_stage = new_stage
172
+ print('Changed progressive stage to: ', new_stage)
173
+
174
+ def forward(self, x):
175
+ x = self.input_layer(x)
176
+
177
+ modulelist = list(self.body._modules.values())
178
+ for i, l in enumerate(modulelist):
179
+ x = l(x)
180
+ if i == 6:
181
+ c1 = x
182
+ elif i == 20:
183
+ c2 = x
184
+ elif i == 23:
185
+ c3 = x
186
+
187
+ # Infer main W and duplicate it
188
+ w0 = self.styles[0](c3)
189
+ w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
190
+ stage = self.progressive_stage.value
191
+ features = c3
192
+ for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
193
+ if i == self.coarse_ind:
194
+ p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
195
+ features = p2
196
+ elif i == self.middle_ind:
197
+ p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
198
+ features = p1
199
+ delta_i = self.styles[i](features)
200
+ w[:, i] += delta_i
201
+ return w
202
+
203
+
204
+ class BackboneEncoderUsingLastLayerIntoW(Module):
205
+ def __init__(self, num_layers, mode='ir', opts=None):
206
+ super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
207
+ print('Using BackboneEncoderUsingLastLayerIntoW')
208
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
209
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
210
+ blocks = get_blocks(num_layers)
211
+ if mode == 'ir':
212
+ unit_module = bottleneck_IR
213
+ elif mode == 'ir_se':
214
+ unit_module = bottleneck_IR_SE
215
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
216
+ BatchNorm2d(64),
217
+ PReLU(64))
218
+ self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
219
+ self.linear = EqualLinear(512, 512, lr_mul=1)
220
+ modules = []
221
+ for block in blocks:
222
+ for bottleneck in block:
223
+ modules.append(unit_module(bottleneck.in_channel,
224
+ bottleneck.depth,
225
+ bottleneck.stride))
226
+ self.body = Sequential(*modules)
227
+ log_size = int(math.log(opts.stylegan_size, 2))
228
+ self.style_count = 2 * log_size - 2
229
+
230
+ def forward(self, x):
231
+ x = self.input_layer(x)
232
+ x = self.body(x)
233
+ x = self.output_pool(x)
234
+ x = x.view(-1, 512)
235
+ x = self.linear(x)
236
+ return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
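For stylegan_size=1024, Encoder4Editing produces 2*log2(1024) - 2 = 18 style vectors of dimension 512 (a W+ code). A sketch, assuming the CUDA ops under models/stylegan2/op compile on the host, since importing EqualLinear JIT-builds them:

# Sketch: the e4e encoder maps a 256x256 face to a W+ code of shape [1, 18, 512].
import torch
from models.encoders.psp_encoders import Encoder4Editing

encoder = Encoder4Editing(num_layers=50, mode='ir_se', stylegan_size=1024).eval()
with torch.no_grad():
    w_plus = encoder(torch.randn(1, 3, 256, 256))
print(w_plus.shape)  # torch.Size([1, 18, 512])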
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes).
 
models/stylegan2/__pycache__/model.cpython-39.pyc ADDED
Binary file (15.8 kB).
 
models/stylegan2/model.py ADDED
@@ -0,0 +1,674 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
8
+
9
+ # Adapted from https://github.com/rosinality/stylegan2-pytorch
10
+
11
+ class PixelNorm(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, input):
16
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
17
+
18
+
19
+ def make_kernel(k):
20
+ k = torch.tensor(k, dtype=torch.float32)
21
+
22
+ if k.ndim == 1:
23
+ k = k[None, :] * k[:, None]
24
+
25
+ k /= k.sum()
26
+
27
+ return k
28
+
29
+
30
+ class Upsample(nn.Module):
31
+ def __init__(self, kernel, factor=2):
32
+ super().__init__()
33
+
34
+ self.factor = factor
35
+ kernel = make_kernel(kernel) * (factor ** 2)
36
+ self.register_buffer('kernel', kernel)
37
+
38
+ p = kernel.shape[0] - factor
39
+
40
+ pad0 = (p + 1) // 2 + factor - 1
41
+ pad1 = p // 2
42
+
43
+ self.pad = (pad0, pad1)
44
+
45
+ def forward(self, input):
46
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
47
+
48
+ return out
49
+
50
+
51
+ class Downsample(nn.Module):
52
+ def __init__(self, kernel, factor=2):
53
+ super().__init__()
54
+
55
+ self.factor = factor
56
+ kernel = make_kernel(kernel)
57
+ self.register_buffer('kernel', kernel)
58
+
59
+ p = kernel.shape[0] - factor
60
+
61
+ pad0 = (p + 1) // 2
62
+ pad1 = p // 2
63
+
64
+ self.pad = (pad0, pad1)
65
+
66
+ def forward(self, input):
67
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68
+
69
+ return out
70
+
71
+
72
+ class Blur(nn.Module):
73
+ def __init__(self, kernel, pad, upsample_factor=1):
74
+ super().__init__()
75
+
76
+ kernel = make_kernel(kernel)
77
+
78
+ if upsample_factor > 1:
79
+ kernel = kernel * (upsample_factor ** 2)
80
+
81
+ self.register_buffer('kernel', kernel)
82
+
83
+ self.pad = pad
84
+
85
+ def forward(self, input):
86
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
87
+
88
+ return out
89
+
90
+
91
+ class EqualConv2d(nn.Module):
92
+ def __init__(
93
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
94
+ ):
95
+ super().__init__()
96
+
97
+ self.weight = nn.Parameter(
98
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
99
+ )
100
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
101
+
102
+ self.stride = stride
103
+ self.padding = padding
104
+
105
+ if bias:
106
+ self.bias = nn.Parameter(torch.zeros(out_channel))
107
+
108
+ else:
109
+ self.bias = None
110
+
111
+ def forward(self, input):
112
+ out = F.conv2d(
113
+ input,
114
+ self.weight * self.scale,
115
+ bias=self.bias,
116
+ stride=self.stride,
117
+ padding=self.padding,
118
+ )
119
+
120
+ return out
121
+
122
+ def __repr__(self):
123
+ return (
124
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
125
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
126
+ )
127
+
128
+
129
+ class EqualLinear(nn.Module):
130
+ def __init__(
131
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
132
+ ):
133
+ super().__init__()
134
+
135
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
136
+
137
+ if bias:
138
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
139
+
140
+ else:
141
+ self.bias = None
142
+
143
+ self.activation = activation
144
+
145
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146
+ self.lr_mul = lr_mul
147
+
148
+ def forward(self, input):
149
+ if self.activation:
150
+ out = F.linear(input, self.weight * self.scale)
151
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
152
+
153
+ else:
154
+ out = F.linear(
155
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
156
+ )
157
+
158
+ return out
159
+
160
+ def __repr__(self):
161
+ return (
162
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
163
+ )
164
+
165
+
166
+ class ScaledLeakyReLU(nn.Module):
167
+ def __init__(self, negative_slope=0.2):
168
+ super().__init__()
169
+
170
+ self.negative_slope = negative_slope
171
+
172
+ def forward(self, input):
173
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
174
+
175
+ return out * math.sqrt(2)
176
+
177
+
178
+ class ModulatedConv2d(nn.Module):
179
+ def __init__(
180
+ self,
181
+ in_channel,
182
+ out_channel,
183
+ kernel_size,
184
+ style_dim,
185
+ demodulate=True,
186
+ upsample=False,
187
+ downsample=False,
188
+ blur_kernel=[1, 3, 3, 1],
189
+ ):
190
+ super().__init__()
191
+
192
+ self.eps = 1e-8
193
+ self.kernel_size = kernel_size
194
+ self.in_channel = in_channel
195
+ self.out_channel = out_channel
196
+ self.upsample = upsample
197
+ self.downsample = downsample
198
+
199
+ if upsample:
200
+ factor = 2
201
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
202
+ pad0 = (p + 1) // 2 + factor - 1
203
+ pad1 = p // 2 + 1
204
+
205
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
206
+
207
+ if downsample:
208
+ factor = 2
209
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
210
+ pad0 = (p + 1) // 2
211
+ pad1 = p // 2
212
+
213
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
214
+
215
+ fan_in = in_channel * kernel_size ** 2
216
+ self.scale = 1 / math.sqrt(fan_in)
217
+ self.padding = kernel_size // 2
218
+
219
+ self.weight = nn.Parameter(
220
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
221
+ )
222
+
223
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
224
+
225
+ self.demodulate = demodulate
226
+
227
+ def __repr__(self):
228
+ return (
229
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
230
+ f'upsample={self.upsample}, downsample={self.downsample})'
231
+ )
232
+
233
+ def forward(self, input, style):
234
+ batch, in_channel, height, width = input.shape
235
+
236
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
237
+ weight = self.scale * self.weight * style
238
+
239
+ if self.demodulate:
240
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
241
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
242
+
243
+ weight = weight.view(
244
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
245
+ )
246
+
247
+ if self.upsample:
248
+ input = input.view(1, batch * in_channel, height, width)
249
+ weight = weight.view(
250
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
251
+ )
252
+ weight = weight.transpose(1, 2).reshape(
253
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
254
+ )
255
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
256
+ _, _, height, width = out.shape
257
+ out = out.view(batch, self.out_channel, height, width)
258
+ out = self.blur(out)
259
+
260
+ elif self.downsample:
261
+ input = self.blur(input)
262
+ _, _, height, width = input.shape
263
+ input = input.view(1, batch * in_channel, height, width)
264
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
265
+ _, _, height, width = out.shape
266
+ out = out.view(batch, self.out_channel, height, width)
267
+
268
+ else:
269
+ input = input.view(1, batch * in_channel, height, width)
270
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
271
+ _, _, height, width = out.shape
272
+ out = out.view(batch, self.out_channel, height, width)
273
+
274
+ return out
275
+
276
+
277
+ class NoiseInjection(nn.Module):
278
+ def __init__(self):
279
+ super().__init__()
280
+
281
+ self.weight = nn.Parameter(torch.zeros(1))
282
+
283
+ def forward(self, image, noise=None):
284
+ if noise is None:
285
+ batch, _, height, width = image.shape
286
+ noise = image.new_empty(batch, 1, height, width).normal_()
287
+
288
+ return image + self.weight * noise
289
+
290
+
291
+ class ConstantInput(nn.Module):
292
+ def __init__(self, channel, size=4):
293
+ super().__init__()
294
+
295
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
296
+
297
+ def forward(self, input):
298
+ batch = input.shape[0]
299
+ out = self.input.repeat(batch, 1, 1, 1)
300
+
301
+ return out
302
+
303
+
304
+ class StyledConv(nn.Module):
305
+ def __init__(
306
+ self,
307
+ in_channel,
308
+ out_channel,
309
+ kernel_size,
310
+ style_dim,
311
+ upsample=False,
312
+ blur_kernel=[1, 3, 3, 1],
313
+ demodulate=True,
314
+ ):
315
+ super().__init__()
316
+
317
+ self.conv = ModulatedConv2d(
318
+ in_channel,
319
+ out_channel,
320
+ kernel_size,
321
+ style_dim,
322
+ upsample=upsample,
323
+ blur_kernel=blur_kernel,
324
+ demodulate=demodulate,
325
+ )
326
+
327
+ self.noise = NoiseInjection()
328
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
329
+ # self.activate = ScaledLeakyReLU(0.2)
330
+ self.activate = FusedLeakyReLU(out_channel)
331
+
332
+ def forward(self, input, style, noise=None):
333
+ out = self.conv(input, style)
334
+ out = self.noise(out, noise=noise)
335
+ # out = out + self.bias
336
+ out = self.activate(out)
337
+
338
+ return out
339
+
340
+
341
+ class ToRGB(nn.Module):
342
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
343
+ super().__init__()
344
+
345
+ if upsample:
346
+ self.upsample = Upsample(blur_kernel)
347
+
348
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
349
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
350
+
351
+ def forward(self, input, style, skip=None):
352
+ out = self.conv(input, style)
353
+ out = out + self.bias
354
+
355
+ if skip is not None:
356
+ skip = self.upsample(skip)
357
+
358
+ out = out + skip
359
+
360
+ return out
361
+
362
+
363
+ class Generator(nn.Module):
364
+ def __init__(
365
+ self,
366
+ size,
367
+ style_dim,
368
+ n_mlp,
369
+ channel_multiplier=2,
370
+ blur_kernel=[1, 3, 3, 1],
371
+ lr_mlp=0.01,
372
+ ):
373
+ super().__init__()
374
+
375
+ self.size = size
376
+
377
+ self.style_dim = style_dim
378
+
379
+ layers = [PixelNorm()]
380
+
381
+ for i in range(n_mlp):
382
+ layers.append(
383
+ EqualLinear(
384
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
385
+ )
386
+ )
387
+
388
+ self.style = nn.Sequential(*layers)
389
+
390
+ self.channels = {
391
+ 4: 512,
392
+ 8: 512,
393
+ 16: 512,
394
+ 32: 512,
395
+ 64: 256 * channel_multiplier,
396
+ 128: 128 * channel_multiplier,
397
+ 256: 64 * channel_multiplier,
398
+ 512: 32 * channel_multiplier,
399
+ 1024: 16 * channel_multiplier,
400
+ }
401
+
402
+ self.input = ConstantInput(self.channels[4])
403
+ self.conv1 = StyledConv(
404
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
405
+ )
406
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
407
+
408
+ self.log_size = int(math.log(size, 2))
409
+ self.num_layers = (self.log_size - 2) * 2 + 1
410
+
411
+ self.convs = nn.ModuleList()
412
+ self.upsamples = nn.ModuleList()
413
+ self.to_rgbs = nn.ModuleList()
414
+ self.noises = nn.Module()
415
+
416
+ in_channel = self.channels[4]
417
+
418
+ for layer_idx in range(self.num_layers):
419
+ res = (layer_idx + 5) // 2
420
+ shape = [1, 1, 2 ** res, 2 ** res]
421
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
422
+
423
+ for i in range(3, self.log_size + 1):
424
+ out_channel = self.channels[2 ** i]
425
+
426
+ self.convs.append(
427
+ StyledConv(
428
+ in_channel,
429
+ out_channel,
430
+ 3,
431
+ style_dim,
432
+ upsample=True,
433
+ blur_kernel=blur_kernel,
434
+ )
435
+ )
436
+
437
+ self.convs.append(
438
+ StyledConv(
439
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
440
+ )
441
+ )
442
+
443
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
444
+
445
+ in_channel = out_channel
446
+
447
+ self.n_latent = self.log_size * 2 - 2
448
+
449
+ def make_noise(self):
450
+ device = self.input.input.device
451
+
452
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
453
+
454
+ for i in range(3, self.log_size + 1):
455
+ for _ in range(2):
456
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
457
+
458
+ return noises
459
+
460
+ def mean_latent(self, n_latent):
461
+ latent_in = torch.randn(
462
+ n_latent, self.style_dim, device=self.input.input.device
463
+ )
464
+ latent = self.style(latent_in).mean(0, keepdim=True)
465
+
466
+ return latent
467
+
468
+ def get_latent(self, input):
469
+ return self.style(input)
470
+
471
+ def forward(
472
+ self,
473
+ styles,
474
+ return_latents=False,
475
+ return_features=False,
476
+ inject_index=None,
477
+ truncation=1,
478
+ truncation_latent=None,
479
+ input_is_latent=False,
480
+ noise=None,
481
+ randomize_noise=True,
482
+ ):
483
+ if not input_is_latent:
484
+ styles = [self.style(s) for s in styles]
485
+
486
+ if noise is None:
487
+ if randomize_noise:
488
+ noise = [None] * self.num_layers
489
+ else:
490
+ noise = [
491
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
492
+ ]
493
+
494
+ if truncation < 1:
495
+ style_t = []
496
+
497
+ for style in styles:
498
+ style_t.append(
499
+ truncation_latent + truncation * (style - truncation_latent)
500
+ )
501
+
502
+ styles = style_t
503
+
504
+ if len(styles) < 2:
505
+ inject_index = self.n_latent
506
+
507
+ if styles[0].ndim < 3:
508
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
509
+ else:
510
+ latent = styles[0]
511
+
512
+ else:
513
+ if inject_index is None:
514
+ inject_index = random.randint(1, self.n_latent - 1)
515
+
516
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
517
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
518
+
519
+ latent = torch.cat([latent, latent2], 1)
520
+
521
+ out = self.input(latent)
522
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
523
+
524
+ skip = self.to_rgb1(out, latent[:, 1])
525
+
526
+ i = 1
527
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
528
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
529
+ ):
530
+ out = conv1(out, latent[:, i], noise=noise1)
531
+ out = conv2(out, latent[:, i + 1], noise=noise2)
532
+ skip = to_rgb(out, latent[:, i + 2], skip)
533
+
534
+ i += 2
535
+
536
+ image = skip
537
+
538
+ if return_latents:
539
+ return image, latent
540
+ elif return_features:
541
+ return image, out
542
+ else:
543
+ return image, None
544
+
545
+
546
+ class ConvLayer(nn.Sequential):
547
+ def __init__(
548
+ self,
549
+ in_channel,
550
+ out_channel,
551
+ kernel_size,
552
+ downsample=False,
553
+ blur_kernel=[1, 3, 3, 1],
554
+ bias=True,
555
+ activate=True,
556
+ ):
557
+ layers = []
558
+
559
+ if downsample:
560
+ factor = 2
561
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
562
+ pad0 = (p + 1) // 2
563
+ pad1 = p // 2
564
+
565
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
566
+
567
+ stride = 2
568
+ self.padding = 0
569
+
570
+ else:
571
+ stride = 1
572
+ self.padding = kernel_size // 2
573
+
574
+ layers.append(
575
+ EqualConv2d(
576
+ in_channel,
577
+ out_channel,
578
+ kernel_size,
579
+ padding=self.padding,
580
+ stride=stride,
581
+ bias=bias and not activate,
582
+ )
583
+ )
584
+
585
+ if activate:
586
+ if bias:
587
+ layers.append(FusedLeakyReLU(out_channel))
588
+
589
+ else:
590
+ layers.append(ScaledLeakyReLU(0.2))
591
+
592
+ super().__init__(*layers)
593
+
594
+
595
+ class ResBlock(nn.Module):
596
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
597
+ super().__init__()
598
+
599
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
600
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
601
+
602
+ self.skip = ConvLayer(
603
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
604
+ )
605
+
606
+ def forward(self, input):
607
+ out = self.conv1(input)
608
+ out = self.conv2(out)
609
+
610
+ skip = self.skip(input)
611
+ out = (out + skip) / math.sqrt(2)
612
+
613
+ return out
614
+
615
+
616
+ class Discriminator(nn.Module):
617
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
618
+ super().__init__()
619
+
620
+ channels = {
621
+ 4: 512,
622
+ 8: 512,
623
+ 16: 512,
624
+ 32: 512,
625
+ 64: 256 * channel_multiplier,
626
+ 128: 128 * channel_multiplier,
627
+ 256: 64 * channel_multiplier,
628
+ 512: 32 * channel_multiplier,
629
+ 1024: 16 * channel_multiplier,
630
+ }
631
+
632
+ convs = [ConvLayer(3, channels[size], 1)]
633
+
634
+ log_size = int(math.log(size, 2))
635
+
636
+ in_channel = channels[size]
637
+
638
+ for i in range(log_size, 2, -1):
639
+ out_channel = channels[2 ** (i - 1)]
640
+
641
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
642
+
643
+ in_channel = out_channel
644
+
645
+ self.convs = nn.Sequential(*convs)
646
+
647
+ self.stddev_group = 4
648
+ self.stddev_feat = 1
649
+
650
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
651
+ self.final_linear = nn.Sequential(
652
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
653
+ EqualLinear(channels[4], 1),
654
+ )
655
+
656
+ def forward(self, input):
657
+ out = self.convs(input)
658
+
659
+ batch, channel, height, width = out.shape
660
+ group = min(batch, self.stddev_group)
661
+ stddev = out.view(
662
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
663
+ )
664
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
665
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
666
+ stddev = stddev.repeat(group, 1, height, width)
667
+ out = torch.cat([out, stddev], 1)
668
+
669
+ out = self.final_conv(out)
670
+
671
+ out = out.view(batch, -1)
672
+ out = self.final_linear(out)
673
+
674
+ return out
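A usage sketch for the generator, assuming the rosinality-format checkpoint referenced in configs.py (with a "g_ema" state dict) and a CUDA device for the JIT-compiled ops:

# Sketch: sample one 1024x1024 face from the FFHQ config-f generator.
import torch
from models.stylegan2.model import Generator

g = Generator(size=1024, style_dim=512, n_mlp=8).cuda().eval()
ckpt = torch.load("checkpoints/stylegan2-ffhq-config-f.pt")   # path as in configs.py
g.load_state_dict(ckpt["g_ema"], strict=False)
with torch.no_grad():
    z = torch.randn(1, 512, device="cuda")
    image, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))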
models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (255 Bytes).
 
models/stylegan2/op/__pycache__/fused_act.cpython-39.pyc ADDED
Binary file (2.83 kB).
 
models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+ # from . import fused
8
+
9
+ module_path = os.path.dirname(__file__)
10
+ fused = load(
11
+ 'fused',
12
+ sources=[
13
+ os.path.join(module_path, 'fused_bias_act.cpp'),
14
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
15
+ ],
16
+ )
17
+
18
+
19
+ class FusedLeakyReLUFunctionBackward(Function):
20
+ @staticmethod
21
+ def forward(ctx, grad_output, out, negative_slope, scale):
22
+ ctx.save_for_backward(out)
23
+ ctx.negative_slope = negative_slope
24
+ ctx.scale = scale
25
+
26
+ empty = grad_output.new_empty(0)
27
+
28
+ grad_input = fused.fused_bias_act(
29
+ grad_output, empty, out, 3, 1, negative_slope, scale
30
+ )
31
+
32
+ dim = [0]
33
+
34
+ if grad_input.ndim > 2:
35
+ dim += list(range(2, grad_input.ndim))
36
+
37
+ grad_bias = grad_input.sum(dim).detach()
38
+
39
+ return grad_input, grad_bias
40
+
41
+ @staticmethod
42
+ def backward(ctx, gradgrad_input, gradgrad_bias):
43
+ out, = ctx.saved_tensors
44
+ gradgrad_out = fused.fused_bias_act(
45
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
46
+ )
47
+
48
+ return gradgrad_out, None, None, None
49
+
50
+
51
+ class FusedLeakyReLUFunction(Function):
52
+ @staticmethod
53
+ def forward(ctx, input, bias, negative_slope, scale):
54
+ empty = input.new_empty(0)
55
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
56
+ ctx.save_for_backward(out)
57
+ ctx.negative_slope = negative_slope
58
+ ctx.scale = scale
59
+
60
+ return out
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ out, = ctx.saved_tensors
65
+
66
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
67
+ grad_output, out, ctx.negative_slope, ctx.scale
68
+ )
69
+
70
+ return grad_input, grad_bias, None, None
71
+
72
+
73
+ class FusedLeakyReLU(nn.Module):
74
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
75
+ super().__init__()
76
+
77
+ self.bias = nn.Parameter(torch.zeros(channel))
78
+ self.negative_slope = negative_slope
79
+ self.scale = scale
80
+
81
+ def forward(self, input):
82
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
83
+
84
+
85
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
86
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
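This module JIT-compiles the CUDA kernel at import time via torch.utils.cpp_extension.load. For reference, the forward op it fuses is just a per-channel bias add, leaky ReLU, and sqrt(2) rescale; a plain-PyTorch sketch of the same computation (reference only, not the code path this commit uses):

# Reference-only sketch of what fused_leaky_relu computes on the forward pass.
import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # bias is broadcast over the channel dimension (dim=1)
    shape = [1, -1] + [1] * (x.ndim - 2)
    return F.leaky_relu(x + bias.view(*shape), negative_slope) * scale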
models/stylegan2/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
models/stylegan2/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
models/stylegan2/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,186 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.nn import functional as F
6
+ # from . import upfirdn2d as upfirdn2d_op
7
+ from torch.utils.cpp_extension import load
8
+
9
+ module_path = os.path.dirname(__file__)
10
+ upfirdn2d_op = load(
11
+ 'upfirdn2d',
12
+ sources=[
13
+ os.path.join(module_path, 'upfirdn2d.cpp'),
14
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
15
+ ],
16
+ )
17
+
18
+
19
+ class UpFirDn2dBackward(Function):
20
+ @staticmethod
21
+ def forward(
22
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23
+ ):
24
+ up_x, up_y = up
25
+ down_x, down_y = down
26
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
27
+
28
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
29
+
30
+ grad_input = upfirdn2d_op.upfirdn2d(
31
+ grad_output,
32
+ grad_kernel,
33
+ down_x,
34
+ down_y,
35
+ up_x,
36
+ up_y,
37
+ g_pad_x0,
38
+ g_pad_x1,
39
+ g_pad_y0,
40
+ g_pad_y1,
41
+ )
42
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
43
+
44
+ ctx.save_for_backward(kernel)
45
+
46
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
47
+
48
+ ctx.up_x = up_x
49
+ ctx.up_y = up_y
50
+ ctx.down_x = down_x
51
+ ctx.down_y = down_y
52
+ ctx.pad_x0 = pad_x0
53
+ ctx.pad_x1 = pad_x1
54
+ ctx.pad_y0 = pad_y0
55
+ ctx.pad_y1 = pad_y1
56
+ ctx.in_size = in_size
57
+ ctx.out_size = out_size
58
+
59
+ return grad_input
60
+
61
+ @staticmethod
62
+ def backward(ctx, gradgrad_input):
63
+ kernel, = ctx.saved_tensors
64
+
65
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
66
+
67
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
68
+ gradgrad_input,
69
+ kernel,
70
+ ctx.up_x,
71
+ ctx.up_y,
72
+ ctx.down_x,
73
+ ctx.down_y,
74
+ ctx.pad_x0,
75
+ ctx.pad_x1,
76
+ ctx.pad_y0,
77
+ ctx.pad_y1,
78
+ )
79
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
80
+ gradgrad_out = gradgrad_out.view(
81
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
82
+ )
83
+
84
+ return gradgrad_out, None, None, None, None, None, None, None, None
85
+
86
+
87
+ class UpFirDn2d(Function):
88
+ @staticmethod
89
+ def forward(ctx, input, kernel, up, down, pad):
90
+ up_x, up_y = up
91
+ down_x, down_y = down
92
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
93
+
94
+ kernel_h, kernel_w = kernel.shape
95
+ batch, channel, in_h, in_w = input.shape
96
+ ctx.in_size = input.shape
97
+
98
+ input = input.reshape(-1, in_h, in_w, 1)
99
+
100
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101
+
102
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104
+ ctx.out_size = (out_h, out_w)
105
+
106
+ ctx.up = (up_x, up_y)
107
+ ctx.down = (down_x, down_y)
108
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109
+
110
+ g_pad_x0 = kernel_w - pad_x0 - 1
111
+ g_pad_y0 = kernel_h - pad_y0 - 1
112
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114
+
115
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116
+
117
+ out = upfirdn2d_op.upfirdn2d(
118
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
119
+ )
120
+ # out = out.view(major, out_h, out_w, minor)
121
+ out = out.view(-1, channel, out_h, out_w)
122
+
123
+ return out
124
+
125
+ @staticmethod
126
+ def backward(ctx, grad_output):
127
+ kernel, grad_kernel = ctx.saved_tensors
128
+
129
+ grad_input = UpFirDn2dBackward.apply(
130
+ grad_output,
131
+ kernel,
132
+ grad_kernel,
133
+ ctx.up,
134
+ ctx.down,
135
+ ctx.pad,
136
+ ctx.g_pad,
137
+ ctx.in_size,
138
+ ctx.out_size,
139
+ )
140
+
141
+ return grad_input, None, None, None, None
142
+
143
+
144
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145
+ out = UpFirDn2d.apply(
146
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
147
+ )
148
+
149
+ return out
150
+
151
+
152
+ def upfirdn2d_native(
153
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
154
+ ):
155
+ _, in_h, in_w, minor = input.shape
156
+ kernel_h, kernel_w = kernel.shape
157
+
158
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
159
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
160
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
161
+
162
+ out = F.pad(
163
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
164
+ )
165
+ out = out[
166
+ :,
167
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
168
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
169
+ :,
170
+ ]
171
+
172
+ out = out.permute(0, 3, 1, 2)
173
+ out = out.reshape(
174
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
175
+ )
176
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
177
+ out = F.conv2d(out, w)
178
+ out = out.reshape(
179
+ -1,
180
+ minor,
181
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
182
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
183
+ )
184
+ out = out.permute(0, 2, 3, 1)
185
+
186
+ return out[:, ::down_y, ::down_x, :]
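The Upsample module in models/stylegan2/model.py pads so that the 4-tap [1, 3, 3, 1] blur kernel exactly doubles the spatial size. A small sketch of the output-size formula used in UpFirDn2d.forward (the helper name is illustrative):

# Sketch: check that Upsample's (pad0, pad1) choice makes a 16x16 map 32x32.
def upfirdn2d_out_size(in_size, kernel_size, up, down, pad0, pad1):
    return (in_size * up + pad0 + pad1 - kernel_size) // down + 1

factor, k = 2, 4                                  # blur_kernel [1, 3, 3, 1] has 4 taps
p = k - factor
pad0, pad1 = (p + 1) // 2 + factor - 1, p // 2    # (2, 1), as in Upsample.__init__
assert upfirdn2d_out_size(16, k, factor, 1, pad0, pad1) == 32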
models/stylegan2/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h = -1;
175
+ int tile_out_w = -1;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w > 0) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
272
+ }
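A small Python helper, illustrative only, that mirrors the output-size formula and symmetric padding used by upfirdn2d_op above; it is handy for sanity-checking pad values before a kernel launch.

def upfirdn2d_out_shape(in_h, in_w, kernel_h, kernel_w, up=1, down=1, pad=(0, 0)):
    # same integer arithmetic as p.out_h / p.out_w in upfirdn2d_op
    pad0, pad1 = pad
    out_h = (in_h * up + pad0 + pad1 - kernel_h + down) // down
    out_w = (in_w * up + pad0 + pad1 - kernel_w + down) // down
    return out_h, out_w

# e.g. a 2x upsample with a 4-tap blur kernel and pad=(2, 1)
print(upfirdn2d_out_shape(64, 64, 4, 4, up=2, down=1, pad=(2, 1)))  # (128, 128)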
models/stylegene/__init__.py ADDED
File without changes
models/stylegene/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes).
models/stylegene/__pycache__/api.cpython-39.pyc ADDED
Binary file (3.28 kB).
models/stylegene/api.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from models.stylegan2.model import Generator
5
+ from models.encoders.psp_encoders import Encoder4Editing
6
+ from models.stylegene.model import MappingSub2W, MappingW2Sub
7
+ from models.stylegene.util import get_keys, requires_grad, load_img
8
+ from models.stylegene.gene_pool import GenePoolFactory
9
+ from models.stylegene.gene_crossover_mutation import fuse_latent
10
+ from models.stylegene.fair_face_model import init_fair_model, predict_race
11
+ from configs import path_ckpt_e4e, path_ckpt_stylegan2, path_ckpt_stylegene, path_ckpt_genepool, path_dataset_ffhq
12
+ from preprocess.align_images import align_face
13
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
14
+
15
+
16
+ def init_model(image_size=1024, latent_dim=512):
17
+ ckp = torch.load(path_ckpt_e4e, map_location='cpu')
18
+ encoder = Encoder4Editing(50, 'ir_se', image_size).eval()
19
+ encoder.load_state_dict(get_keys(ckp, 'encoder'), strict=True)
20
+ mean_latent = ckp['latent_avg'].to('cpu')
21
+ mean_latent.unsqueeze_(0)
22
+
23
+ generator = Generator(image_size, latent_dim, 8)
24
+ checkpoint = torch.load(path_ckpt_stylegan2, map_location='cpu')
25
+ generator.load_state_dict(checkpoint["g_ema"], strict=False)
26
+ generator.eval()
27
+ sub2w = MappingSub2W(N=18).eval()
28
+ w2sub34 = MappingW2Sub(N=18).eval()
29
+ ckp = torch.load(path_ckpt_stylegene, map_location='cpu')
30
+ w2sub34.load_state_dict(get_keys(ckp, 'w2sub34'))
31
+ sub2w.load_state_dict(get_keys(ckp, 'sub2w'))
32
+
33
+ requires_grad(sub2w, False)
34
+ requires_grad(w2sub34, False)
35
+ requires_grad(encoder, False)
36
+ requires_grad(generator, False)
37
+ return encoder, generator, sub2w, w2sub34, mean_latent
38
+
39
+
40
+ # init model
41
+ encoder, generator, sub2w, w2sub34, mean_latent = init_model()
42
+ encoder, generator, sub2w, w2sub34, mean_latent = encoder.to(device), generator.to(device), sub2w.to(
43
+ device), w2sub34.to(device), mean_latent.to(device)
44
+ model_fair_7 = init_fair_model(device) # init FairFace model
45
+
46
+ # load a GenePool
47
+ geneFactor = GenePoolFactory(root_ffhq=path_dataset_ffhq, device=device, mean_latent=mean_latent, max_sample=300)
48
+ geneFactor.pools = torch.load(path_ckpt_genepool)
49
+ print("gene pool loaded!")
50
+
51
+
52
+ def tensor2rgb(tensor):
53
+ tensor = (tensor * 0.5 + 0.5) * 255
54
+ tensor = torch.clip(tensor, 0, 255).squeeze(0)
55
+ tensor = tensor.detach().cpu().numpy().transpose(1, 2, 0)
56
+ tensor = tensor.astype(np.uint8)
57
+ return tensor
58
+
59
+
60
+ def generate_child(w18_F, w18_M, random_fakes, gamma=0.46, eta=0.4):
61
+ w18_syn = fuse_latent(w2sub34, sub2w, w18_F=w18_F, w18_M=w18_M,
62
+ random_fakes=random_fakes, fixed_gamma=gamma, fixed_eta=eta)
63
+
64
+ img_C, _ = generator([w18_syn], return_latents=True, input_is_latent=True)
65
+ return img_C, w18_syn
66
+
67
+
68
+ def synthesize_descendant(pF, pM, attributes=None):
69
+ gender_all = ['male', 'female']
70
+ ages_all = ['0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70+']
71
+ if attributes is None:
72
+ attributes = {'age': ages_all[0], 'gender': gender_all[0], 'gamma': 0.47, 'eta': 0.4}
73
+ imgF = align_face(pF)
74
+ imgM = align_face(pM)
75
+ imgF = load_img(imgF)
76
+ imgM = load_img(imgM)
77
+ imgF, imgM = imgF.to(device), imgM.to(device)
78
+
79
+ father_race, _, _, _ = predict_race(model_fair_7, imgF.clone(), imgF.device)
80
+ mother_race, _, _, _ = predict_race(model_fair_7, imgM.clone(), imgM.device)
81
+
82
+ w18_1 = encoder(F.interpolate(imgF, size=(256, 256))) + mean_latent
83
+ w18_2 = encoder(F.interpolate(imgM, size=(256, 256))) + mean_latent
84
+
85
+ random_fakes = []
86
+ for r in list({father_race, mother_race}): # search RFGs from Gene Pool
87
+ random_fakes = random_fakes + geneFactor(encoder, w2sub34, attributes['age'], attributes['gender'], r)
88
+ img_C, w18_syn = generate_child(w18_1.clone(), w18_2.clone(), random_fakes,
89
+ gamma=attributes['gamma'], eta=attributes['eta'])
90
+ img_C = tensor2rgb(img_C)
91
+ img_F = tensor2rgb(imgF)
92
+ img_M = tensor2rgb(imgM)
93
+
94
+ return img_F, img_M, img_C
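Illustrative usage of the API above; the input paths and attribute values are placeholders. pF and pM are RGB images as numpy arrays, and the returned images are uint8 arrays ready for display or saving.

import numpy as np
from PIL import Image

pF = np.array(Image.open("father.jpg").convert("RGB"))   # hypothetical input files
pM = np.array(Image.open("mother.jpg").convert("RGB"))
attributes = {'age': '3-9', 'gender': 'female', 'gamma': 0.47, 'eta': 0.4}
img_F, img_M, img_C = synthesize_descendant(pF, pM, attributes)
Image.fromarray(img_C).save("child.png")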
models/stylegene/data_util.py ADDED
@@ -0,0 +1,36 @@
1
+ """
2
+ Copyright (C) 2021 NVIDIA Corporation. All rights reserved.
3
+ Licensed under The MIT License (MIT)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
6
+ this software and associated documentation files (the "Software"), to deal in
7
+ the Software without restriction, including without limitation the rights to
8
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9
+ the Software, and to permit persons to whom the Software is furnished to do so,
10
+ subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
+ """
22
+
23
+ face_class = ['background', 'head', 'head***cheek', 'head***chin', 'head***ear', 'head***ear***helix',
24
+ 'head***ear***lobule', 'head***eye***botton lid', 'head***eye***eyelashes', 'head***eye***iris',
25
+ 'head***eye***pupil', 'head***eye***sclera', 'head***eye***tear duct', 'head***eye***top lid',
26
+ 'head***eyebrow', 'head***forehead', 'head***frown', 'head***hair', 'head***hair***sideburns',
27
+ 'head***jaw', 'head***moustache', 'head***mouth***inferior lip', 'head***mouth***oral comisure',
28
+ 'head***mouth***superior lip', 'head***mouth***teeth', 'head***neck', 'head***nose',
29
+ 'head***nose***ala of nose', 'head***nose***bridge', 'head***nose***nose tip', 'head***nose***nostril',
30
+ 'head***philtrum', 'head***temple', 'head***wrinkles']
31
+ face_must = ['head***forehead', 'head***frown', 'head***nose***bridge', 'head***nose', 'head***nose***ala of nose',
32
+ 'head***nose***nose tip', 'head***nose***nostril', 'head***mouth***inferior lip',
33
+ 'head***mouth***superior lip', 'head***chin', 'head***eye***top lid', 'head***eye***pupil',
34
+ 'head***eye***iris', 'head***eye***tear duct']
35
+
36
+ face_shape = ['head', 'head***chin', 'head***jaw', 'head***moustache', 'head***cheek']
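The lists above are consumed positionally elsewhere (e.g. the region index i in gene_crossover_mutation.py); a small illustrative lookup makes the name-to-index mapping explicit.

region_index = {name: i for i, name in enumerate(face_class)}
print(region_index['head***nose'])          # index used for the nose region
print(len(face_class), len(face_must))      # 34 region classes, 14 'must' regions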
models/stylegene/fair_face_model.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torch import nn
5
+ import torchvision
6
+ import torch.nn.functional as F
7
+ from torchvision import transforms
8
+ from configs import path_ckpt_fairface
9
+
10
+ # code adapted from https://github.com/dchen236/FairFace
11
+
12
+ def init_fair_model(device, path_ckpt=None):
13
+ if path_ckpt is None:
14
+ path_ckpt = path_ckpt_fairface
15
+ model_fair_7 = torchvision.models.resnet34(pretrained=False)
16
+ model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
17
+ model_fair_7.load_state_dict(
18
+ torch.load(path_ckpt))
19
+ model_fair_7 = model_fair_7.to(device)
20
+ model_fair_7.eval()
21
+ return model_fair_7
22
+
23
+
24
+ def predict_race(model_fair_7, path_img, device):
25
+ if type(path_img) == str:
26
+ trans = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30
+ ])
31
+ image = Image.open(path_img)
32
+ image = trans(image)
33
+ image = image.view(1, 3, 224, 224) # reshape image to match model dimensions (1 batch size)
34
+ elif type(path_img) == torch.Tensor:
35
+ trans = transforms.Compose([
36
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37
+ ])
38
+ image = F.interpolate(path_img, (224, 224))
39
+ image = image * 0.5 + 0.5
40
+ image = trans(image)
41
+ image = image.view(1, 3, 224, 224)
42
+
43
+ image = image.to(device)
44
+
45
+ outputs = model_fair_7(image)
46
+ outputs = outputs.cpu().detach().numpy()
47
+ outputs = np.squeeze(outputs)
48
+
49
+ race_outputs = outputs[:7]
50
+ gender_outputs = outputs[7:9]
51
+ age_outputs = outputs[9:18]
52
+
53
+ race_score = np.exp(race_outputs) / np.sum(np.exp(race_outputs))
54
+ gender_score = np.exp(gender_outputs) / np.sum(np.exp(gender_outputs))
55
+ age_score = np.exp(age_outputs) / np.sum(np.exp(age_outputs))
56
+
57
+ race_pred = np.argmax(race_score)
58
+ gender_pred = np.argmax(gender_score)
59
+ age_pred = np.argmax(age_score)
60
+ race_label = ['White', 'Black', 'Latino_Hispanic', 'East Asian', 'Southeast Asian', 'Indian', 'Middle Eastern']
61
+ return race_label[race_pred], race_pred, gender_pred, age_pred
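Illustrative call of the predictor above; the image path is a placeholder. predict_race accepts either an image path or a tensor normalized to [-1, 1], matching the two branches in the function.

import torch

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_fair_7 = init_fair_model(device)
race, race_idx, gender_idx, age_idx = predict_race(model_fair_7, "aligned_face.jpg", device)
print(race, race_idx, gender_idx, age_idx)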
models/stylegene/gene_crossover_mutation.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from .data_util import face_class, face_shape
3
+ import random
4
+
5
+
6
+ def reparameterize(mu, logvar):
7
+ """
8
+ Reparameterization trick: sample from N(mu, var) using a draw from
9
+ N(0, 1).
10
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
11
+ :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
12
+ :return: (Tensor) [B x D]
13
+ """
14
+ std = torch.exp(0.5 * logvar)
15
+ eps = torch.randn_like(std)
16
+
17
+ return eps * std + mu
18
+
19
+
20
+ def mix(w18_F, w18_M, w18_syn):
21
+ for k in [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]:
22
+ w18_syn[:, k, :] = w18_F[:, k, :] * 0.5 + w18_M[:, k, :] * 0.5
23
+ return w18_syn
24
+
25
+
26
+ def fuse_latent(w2sub34, sub2w, w18_F, w18_M, random_fakes, fixed_gamma=0.47, fixed_eta=0.4):
27
+ device = w18_F.device
28
+
29
+ mu_F, var_F, sub34_F = w2sub34(w18_F)
30
+ mu_M, var_M, sub34_M = w2sub34(w18_M)
31
+ new_sub34 = torch.zeros_like(sub34_F, dtype=torch.float, device=device)
32
+
33
+ if len(random_fakes) == 0: # EXCEPTION HANDLER (No matching gene pool)
34
+ random_fakes = [(mu_F.cpu(), var_F.cpu())] + [(mu_M.cpu(), var_M.cpu())]
35
+
36
+ # region genetic variation weights
37
+ weights = {}
38
+ for i in face_class:
39
+ weights[i] = (random.uniform(0, 1 - float(fixed_gamma)), float(fixed_gamma))
40
+
41
+ # select genetic regions
42
+ cur_class = random.sample(face_class, int(len(face_class) * (1 - float(fixed_eta))))
43
+
44
+ for i, classname in enumerate(face_class):
45
+ if classname == 'background':
46
+ new_sub34[:, :, i, :] = reparameterize(mu_F[:, :, i, :], var_F[:, :, i, :])
47
+ continue
48
+
49
+ if classname in cur_class: # corresponding to t = 0 in Eq.10
50
+ fake_mu, fake_var = random.choice(random_fakes)
51
+ w_i, b_i = weights[classname]
52
+ new_sub34[:, :, i, :] = reparameterize(
53
+ mu_F[:, :, i, :] * w_i + fake_mu[:, :, i, :].to(device) * b_i + mu_M[:, :, i, :] * (1 - w_i - b_i),
54
+ var_F[:, :, i, :] * w_i + fake_var[:, :, i, :].to(device) * b_i + var_M[:, :, i, :] * (1 - w_i - b_i))
55
+ else: # corresponding to t = 1 in Eq.10
56
+ fake_mu, fake_var = random.choice(random_fakes)
57
+ fake_latent = reparameterize(fake_mu[:, :, i, :], fake_var[:, :, i, :]).to(device)
58
+ var = fake_latent
59
+ new_sub34[:, :, i, :] = new_sub34[:, :, i, :] + var
60
+ w18_syn = sub2w(new_sub34)
61
+
62
+ w18_syn = mix(w18_F, w18_M, w18_syn)
63
+
64
+ return w18_syn
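A toy numeric sketch of the per-region blending above: for a region kept from the parents (t = 0), the mean is mixed as w*mu_F + gamma*mu_RFG + (1 - w - gamma)*mu_M with w drawn uniformly from [0, 1 - gamma], while roughly an eta fraction of regions is resampled from the gene pool (t = 1). The values below are illustrative.

import random

gamma, eta = 0.47, 0.4
w = random.uniform(0, 1 - gamma)
print(f"father {w:.2f}, gene pool {gamma:.2f}, mother {1 - w - gamma:.2f}")
n_from_pool = 34 - int(34 * (1 - eta))       # regions taken directly from the pool
print(f"{n_from_pool} of 34 regions resampled from the gene pool")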
models/stylegene/gene_pool.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ import torch.nn.functional as F
5
+
6
+ from .util import load_img
7
+ from configs import path_csv_ffhq_attritube
8
+
9
+
10
+ class GenePoolFactory(object):
11
+ def __init__(self, root_ffhq, device, mean_latent, max_sample=100):
12
+ self.device = device
13
+ self.mean_latent = mean_latent
14
+ self.root_ffhq = root_ffhq
15
+ self.max_sample = max_sample
16
+
17
+ self.pools = {}
18
+ path_ffhq_attributes = path_csv_ffhq_attritube
19
+ self.df = pd.read_csv(path_ffhq_attributes)
20
+ self.df.replace('Male', 'male', inplace=True)
21
+ self.df.replace('Female', 'female', inplace=True)
22
+
23
+ def __call__(self, encoder, w2sub34, age, gender, race):
24
+ keyname = f'{age}-{gender}-{race}'
25
+ if keyname in self.pools.keys():
26
+ return self.pools[keyname]
27
+ elif self.root_ffhq is not None:
28
+ result = self.df.query(f'gender == "{gender}" and age == "{age}" and race == "{race}"')
29
+ result = result[['file_id']].values
30
+ tmp = []
31
+ random.shuffle(result)
32
+ for fid in result[:self.max_sample]:
33
+ filename = format(int(fid[0]), '05d') + ".png"
34
+ img = load_img(os.path.join(self.root_ffhq, filename))
35
+ img = img.to(self.device)
36
+ w18_1 = encoder(F.interpolate(img, size=(256, 256))) + self.mean_latent
37
+ mu, var, sub34_1 = w2sub34(w18_1)
38
+ tmp.append((mu.cpu(), var.cpu()))
39
+ self.pools[keyname] = tmp
40
+ return self.pools[keyname]
41
+ else:
42
+ return []
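Illustrative offline usage: populate pools for a few attribute combinations and cache them with torch.save so the app can later load them via torch.load, as api.py does. encoder, w2sub34, device and mean_latent are assumed to come from init_model(); the paths are placeholders.

import torch

factory = GenePoolFactory(root_ffhq="/path/to/ffhq", device=device,
                          mean_latent=mean_latent, max_sample=300)
for age in ['0-2', '3-9']:
    for gender in ['male', 'female']:
        factory(encoder, w2sub34, age, gender, 'East Asian')
torch.save(factory.pools, "checkpoints/gene_pool.pt")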
models/stylegene/model.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch import nn
3
+ from functools import partial
4
+ from einops.layers.torch import Rearrange, Reduce
5
+ from einops import rearrange
6
+
7
+ pair = lambda x: x if isinstance(x, tuple) else (x, x)
8
+
9
+
10
+ class PreNormResidual(nn.Module):
11
+ def __init__(self, dim, fn):
12
+ super().__init__()
13
+ self.fn = fn
14
+ self.norm = nn.LayerNorm(dim)
15
+
16
+ def forward(self, x):
17
+ return self.fn(self.norm(x)) + x
18
+
19
+
20
+ def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
21
+ inner_dim = int(dim * expansion_factor)
22
+ return nn.Sequential(
23
+ dense(dim, inner_dim),
24
+ nn.GELU(),
25
+ nn.Dropout(dropout),
26
+ dense(inner_dim, dim),
27
+ nn.Dropout(dropout)
28
+ )
29
+
30
+
31
+ class MappingSub2W(nn.Module):
32
+ def __init__(self, N=8, dim=512, depth=6, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
33
+ super(MappingSub2W, self).__init__()
34
+ num_patches = N * 34
35
+
36
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
37
+ self.layer = nn.Sequential(
38
+ Rearrange('b c h w -> b (c h) w'),
39
+ *[nn.Sequential(
40
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
41
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
42
+ ) for _ in range(depth)],
43
+ nn.LayerNorm(dim),
44
+ Rearrange('b c h -> b h c'),
45
+ nn.Linear(34 * N, 34 * N),
46
+ nn.LayerNorm(34 * N),
47
+ nn.GELU(),
48
+ nn.Linear(34 * N, N),
49
+ Rearrange('b h c -> b c h')
50
+ )
51
+
52
+ def forward(self, x):
53
+ return self.layer(x)
54
+
55
+
56
+ class MappingW2Sub(nn.Module):
57
+ def __init__(self, N=8, dim=512, depth=8, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
58
+ super(MappingW2Sub, self).__init__()
59
+ self.N = N
60
+ num_patches = N * 34
61
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
62
+
63
+ self.layer = nn.Sequential(
64
+ Rearrange('b c h -> b h c'),
65
+ nn.Linear(N, num_patches),
66
+ Rearrange('b h c -> b c h'),
67
+ *[nn.Sequential(
68
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
69
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
70
+ ) for _ in range(depth)],
71
+ nn.LayerNorm(dim)
72
+ )
73
+ self.mu_fc = nn.Sequential(
74
+ *[nn.Sequential(
75
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
76
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
77
+ ) for _ in range(2)],
78
+ nn.LayerNorm(dim),
79
+ nn.Tanh(),
80
+ Rearrange('b c h -> b (c h)')
81
+ )
82
+ self.var_fc = nn.Sequential(
83
+ *[nn.Sequential(
84
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
85
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
86
+ ) for _ in range(2)],
87
+ nn.LayerNorm(dim),
88
+ nn.Tanh(),
89
+ Rearrange('b c h -> b (c h)')
90
+ )
91
+
92
+ def reparameterize(self, mu, logvar):
93
+ """
94
+ Reparameterization trick: sample from N(mu, var) using a draw from
95
+ N(0, 1).
96
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
97
+ :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
98
+ :return: (Tensor) [B x D]
99
+ """
100
+ std = torch.exp(0.5 * logvar)
101
+ eps = torch.randn_like(std)
102
+
103
+ return eps * std + mu
104
+
105
+ def forward(self, x):
106
+ f = self.layer(x)
107
+ mu = self.mu_fc(f)
108
+ var = self.var_fc(f)
109
+
110
+ z = self.reparameterize(mu, var)
111
+ z = rearrange(z, 'a (b c d) -> a b c d', b=self.N, c=34)
112
+ return rearrange(mu, 'a (b c d) -> a b c d', b=self.N, c=34), rearrange(var, 'a (b c d) -> a b c d',
113
+ b=self.N, c=34), z
114
+
115
+
116
+ class HeadEncoder(nn.Module):
117
+ def __init__(self, N=8, dim=512, depth=2, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
118
+ super(HeadEncoder, self).__init__()
119
+ channels = [32, 64, 64, 64]
120
+ self.N = N
121
+ num_patches = N
122
+ chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
123
+
124
+ self.s1 = nn.Sequential(
125
+ nn.Conv2d(channels[0], channels[1], kernel_size=5, padding=2, stride=2),
126
+ nn.BatchNorm2d(channels[1]),
127
+ nn.LeakyReLU(),
128
+ nn.Conv2d(channels[1], channels[2], kernel_size=5, padding=2, stride=2),
129
+ nn.BatchNorm2d(channels[2]),
130
+ nn.LeakyReLU(),
131
+ nn.Conv2d(channels[2], channels[3], kernel_size=5, padding=2, stride=2),
132
+ nn.BatchNorm2d(channels[3]),
133
+ nn.LeakyReLU())
134
+ self.mlp1 = nn.Linear(channels[3] * 8 * 8, 512)
135
+
136
+ self.up_N = nn.Linear(1, N)
137
+
138
+ self.mu_fc = nn.Sequential(
139
+ *[nn.Sequential(
140
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
141
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
142
+ ) for _ in range(depth)],
143
+ nn.LayerNorm(dim),
144
+ nn.Tanh()
145
+ )
146
+ self.var_fc = nn.Sequential(
147
+ *[nn.Sequential(
148
+ PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
149
+ PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
150
+ ) for _ in range(depth)],
151
+ nn.LayerNorm(dim),
152
+ nn.Tanh()
153
+ )
154
+
155
+ def reparameterize(self, mu, logvar):
156
+ """
157
+ Reparameterization trick: sample from N(mu, var) using a draw from
158
+ N(0, 1).
159
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D]
160
+ :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
161
+ :return: (Tensor) [B x D]
162
+ """
163
+ std = torch.exp(0.5 * logvar)
164
+ eps = torch.randn_like(std)
165
+ return eps * std + mu
166
+
167
+ def forward(self, x):
168
+ feature = self.s1(x)
169
+ s2 = torch.flatten(feature, start_dim=1)
170
+ s2 = self.mlp1(s2).unsqueeze(2)
171
+ s2 = self.up_N(s2)
172
+ s2 = rearrange(s2, 'b h c -> b c h')
173
+ mu = self.mu_fc(s2)
174
+ var = self.var_fc(s2)
175
+ z = self.reparameterize(mu, var)
176
+ return mu, var, z
177
+
178
+
179
+ class RegionEncoder(nn.Module):
180
+ def __init__(self, N=8):
181
+ super(RegionEncoder, self).__init__()
182
+ channels = [8, 16, 32, 32, 64, 64]
183
+ self.s1 = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, stride=2)
184
+ self.s2 = nn.Sequential(
185
+ nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, stride=2),
186
+ nn.BatchNorm2d(channels[1]),
187
+ nn.LeakyReLU(),
188
+ nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1, stride=2),
189
+ nn.BatchNorm2d(channels[2]),
190
+ nn.LeakyReLU()
191
+ )
192
+ self.heads = nn.ModuleList()
193
+ for i in range(34):
194
+ self.heads.append(HeadEncoder(N=N))
195
+
196
+ def forward(self, x, all_mask=None):
197
+ s1 = self.s1(x)
198
+ s2 = self.s2(s1)
199
+ result = []
200
+ mus = []
201
+ log_vars = []
202
+ for i, head in enumerate(self.heads):
203
+ m = all_mask[:, i, :].unsqueeze(1)
204
+ mu, var, z = head(s2 * m)
205
+ result.append(z.unsqueeze(2))
206
+ mus.append(mu.unsqueeze(2))
207
+ log_vars.append(var.unsqueeze(2))
208
+
209
+ return torch.cat(mus, dim=2), torch.cat(log_vars, dim=2), torch.cat(result, dim=2)
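A quick shape sanity check, illustrative only, for the two mappers defined above: MappingW2Sub factorizes an 18x512 W+ latent into per-region (mu, logvar, sample) tensors of shape [B, 18, 34, 512], and MappingSub2W maps a region tensor back to W+.

import torch

w2sub = MappingW2Sub(N=18).eval()
sub2w = MappingSub2W(N=18).eval()
w18 = torch.randn(1, 18, 512)
mu, logvar, sub34 = w2sub(w18)
print(mu.shape, sub34.shape)     # torch.Size([1, 18, 34, 512]) for both
print(sub2w(sub34).shape)        # torch.Size([1, 18, 512])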
models/stylegene/util.py ADDED
@@ -0,0 +1,30 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+
5
+
6
+ def requires_grad(model, flag=True):
7
+ for p in model.parameters():
8
+ p.requires_grad = flag
9
+
10
+
11
+ def get_keys(d, name):
12
+ if 'state_dict' in d:
13
+ d = d['state_dict']
14
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
15
+ return d_filt
16
+
17
+
18
+ def load_img(path_img, img_size=(256, 256)):
19
+ transform = transforms.Compose(
20
+ [transforms.Resize(img_size),
21
+ transforms.ToTensor(),
22
+ transforms.Normalize((0.5, 0.5, 0.5),
23
+ (0.5, 0.5, 0.5))])
24
+ if type(path_img) is np.ndarray:
25
+ img = Image.fromarray(path_img)
26
+ else:
27
+ img = Image.open(path_img).convert('RGB')
28
+ img = transform(img)
29
+ img.unsqueeze_(0)
30
+ return img
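Illustrative usage of the helpers above; the file names are placeholders. load_img returns a normalized [1, 3, 256, 256] tensor in [-1, 1], and get_keys pulls a sub-module's weights out of a pSp/e4e-style checkpoint by stripping the prefix.

import torch

img = load_img("aligned_face.png")                            # [1, 3, 256, 256] in [-1, 1]
ckpt = torch.load("e4e_checkpoint.pt", map_location='cpu')    # hypothetical checkpoint
encoder_weights = get_keys(ckpt, 'encoder')                   # 'encoder.xxx' -> 'xxx'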
preprocess/__init__.py ADDED
File without changes
preprocess/align_images.py ADDED
@@ -0,0 +1,32 @@
1
+ import bz2
2
+ from .face_alignment import image_align
3
+ from .landmarks_detector import LandmarksDetector
4
+ from configs import path_ckpt_landmark68
5
+
6
+ LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
7
+
8
+
9
+ def unpack_bz2(src_path):
10
+ data = bz2.BZ2File(src_path).read()
11
+ dst_path = src_path[:-4]
12
+ with open(dst_path, 'wb') as fp:
13
+ fp.write(data)
14
+ return dst_path
15
+
16
+
17
+ def init_landmark():
18
+ landmarks_model_path = unpack_bz2(path_ckpt_landmark68)
19
+ landmarks_detector = LandmarksDetector(landmarks_model_path)
20
+ return landmarks_detector
21
+
22
+
23
+ landmarks_detector = init_landmark()
24
+
25
+
26
+ def align_face(raw_img, output_size=256):
27
+ try:
28
+ face_landmarks = landmarks_detector.get_landmarks(raw_img)[0]
29
+ aligned_face = image_align(raw_img, face_landmarks, output_size=output_size, transform_size=1024)
30
+ return aligned_face
31
+ except:
32
+ return raw_img
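Illustrative usage; the photo path is a placeholder. align_face expects an RGB numpy array and returns a 256x256 aligned crop, or the input unchanged if no face is detected.

import numpy as np
from PIL import Image

raw = np.array(Image.open("photo.jpg").convert("RGB"))
aligned = align_face(raw, output_size=256)
print(aligned.shape)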
preprocess/face_alignment.py ADDED
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+ import scipy.ndimage
3
+ import os
4
+ import PIL.Image
5
+
6
+
7
+ def image_align(src, face_landmarks, output_size=1024, transform_size=4096, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1, alpha=False):
8
+ # Align function from FFHQ dataset pre-processing step
9
+ # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
10
+
11
+ lm = np.array(face_landmarks)
12
+ lm_chin = lm[0 : 17] # left-right
13
+ lm_eyebrow_left = lm[17 : 22] # left-right
14
+ lm_eyebrow_right = lm[22 : 27] # left-right
15
+ lm_nose = lm[27 : 31] # top-down
16
+ lm_nostrils = lm[31 : 36] # top-down
17
+ lm_eye_left = lm[36 : 42] # left-clockwise
18
+ lm_eye_right = lm[42 : 48] # left-clockwise
19
+ lm_mouth_outer = lm[48 : 60] # left-clockwise
20
+ lm_mouth_inner = lm[60 : 68] # left-clockwise
21
+
22
+ # Calculate auxiliary vectors.
23
+ eye_left = np.mean(lm_eye_left, axis=0)
24
+ eye_right = np.mean(lm_eye_right, axis=0)
25
+ eye_avg = (eye_left + eye_right) * 0.5
26
+ eye_to_eye = eye_right - eye_left
27
+ mouth_left = lm_mouth_outer[0]
28
+ mouth_right = lm_mouth_outer[6]
29
+ mouth_avg = (mouth_left + mouth_right) * 0.5
30
+ eye_to_mouth = mouth_avg - eye_avg
31
+
32
+ # Choose oriented crop rectangle.
33
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
34
+ x /= np.hypot(*x)
35
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
36
+ x *= x_scale
37
+ y = np.flipud(x) * [-y_scale, y_scale]
38
+ c = eye_avg + eye_to_mouth * em_scale
39
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
40
+ qsize = np.hypot(*x) * 2
41
+
42
+ # Load in-the-wild image.
43
+ img = PIL.Image.fromarray(src)
44
+
45
+ # Shrink.
46
+ shrink = int(np.floor(qsize / output_size * 0.5))
47
+ if shrink > 1:
48
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
49
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
50
+ quad /= shrink
51
+ qsize /= shrink
52
+
53
+ # Crop.
54
+ border = max(int(np.rint(qsize * 0.1)), 3)
55
+ crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
56
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
57
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
58
+ img = img.crop(crop)
59
+ quad -= crop[0:2]
60
+
61
+ # Pad.
62
+ pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
63
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
64
+ if enable_padding and max(pad) > border - 4:
65
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
66
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
67
+ h, w, _ = img.shape
68
+ y, x, _ = np.ogrid[:h, :w, :1]
69
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
70
+ blur = qsize * 0.02
71
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
72
+ img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
73
+ img = np.uint8(np.clip(np.rint(img), 0, 255))
74
+ if alpha:
75
+ mask = 1-np.clip(3.0 * mask, 0.0, 1.0)
76
+ mask = np.uint8(np.clip(np.rint(mask*255), 0, 255))
77
+ img = np.concatenate((img, mask), axis=2)
78
+ img = PIL.Image.fromarray(img, 'RGBA')
79
+ else:
80
+ img = PIL.Image.fromarray(img, 'RGB')
81
+ quad += pad[:2]
82
+
83
+ # Transform.
84
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
85
+ if output_size < transform_size:
86
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
87
+ return np.array(img)
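A toy numeric sketch, with made-up landmark positions, of the oriented crop above: the x axis follows the eye-to-eye direction corrected by the flipped eye-to-mouth direction, the crop size scales with the larger of the two distances, and the centre is shifted by em_scale along the eye-to-mouth axis.

import numpy as np

eye_left, eye_right = np.array([80., 100.]), np.array([120., 100.])
mouth_avg = np.array([100., 150.])
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg

x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
c = eye_avg + eye_to_mouth * 0.1                     # default em_scale
print("half-extent:", np.hypot(*x), "centre:", c)    # 90.0, [100. 105.]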
preprocess/landmarks_detector.py ADDED
@@ -0,0 +1,24 @@
1
+ import dlib
2
+
3
+
4
+ class LandmarksDetector:
5
+ def __init__(self, predictor_model_path):
6
+ """
7
+ :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
8
+ """
9
+ self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used
10
+ self.shape_predictor = dlib.shape_predictor(predictor_model_path)
11
+
12
+ def get_landmarks(self, img):
13
+ # img = dlib.load_rgb_image(image) # load from path
14
+ dets = self.detector(img, 1)
15
+
16
+ all_faces = []
17
+ for detection in dets:
18
+ try:
19
+ face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
20
+ all_faces.append(face_landmarks)
21
+ except:
22
+ print("Exception in get_landmarks()!")
23
+ return all_faces
24
+
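Illustrative usage; the predictor path and photo are placeholders. get_landmarks returns one list of 68 (x, y) tuples per detected face.

import numpy as np
from PIL import Image

detector = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
img = np.array(Image.open("photo.jpg").convert("RGB"))
faces = detector.get_landmarks(img)
print(len(faces), len(faces[0]) if faces else 0)      # number of faces, 68 points each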
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ Cython==0.29.34
2
+ cytoolz==0.11.0
3
+ dlib==19.24.0
4
+ easydict==1.10
5
+ einops==0.3.2
6
+ gradio==3.32.0
7
+ gradio_client==0.2.5
8
+ huggingface-hub==0.14.1
9
+ omegaconf==2.2.3
10
+ onnx==1.12.0
11
+ onnxruntime==1.9.0
12
+ opencv-contrib-python==4.5.5.64
13
+ opencv-python==4.5.4.60
14
+ opencv-python-headless==4.7.0.72
15
+ pandas==1.4.4
16
+ Pillow==9.2.0
17
+ pytorch-lightning==1.8.1
18
+ torch==1.12.1
19
+ torchvision==0.13.1