hylee committed
Commit eb7d2bb • 1 Parent(s): 545c8cd
README.md CHANGED
@@ -1,4 +1,5 @@
  ---
+ python_version: 3.7
  title: Photo2cartoon
  emoji: πŸ‘
  colorFrom: gray
app.py ADDED
@@ -0,0 +1,100 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+ import argparse
+ import functools
+ import os
+ import pathlib
+ import sys
+ from typing import Callable
+
+
+ import gradio as gr
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+
+ import cv2
+
+ from io import BytesIO
+ sys.path.insert(0, 'p2c')
+
+ from test_onnx import Photo2Cartoon
+
+
+ ORIGINAL_REPO_URL = 'https://github.com/minivision-ai/photo2cartoon'
+ TITLE = 'minivision-ai/photo2cartoon'
+ DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
+
+ """
+ ARTICLE = """
+
+ """
+
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--live', action='store_true')
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     parser.add_argument('--allow-flagging', type=str, default='never')
+     parser.add_argument('--allow-screenshot', action='store_true')
+     return parser.parse_args()
+
+
+
+
+ def run(
+     image,
+     p2c: Photo2Cartoon,
+ ) -> PIL.Image.Image:
+
+     cartoon = p2c.inference(image.name)
+
+     return PIL.Image.fromarray(cartoon)
+
+
+ def main():
+     gr.close_all()
+
+     args = parse_args()
+
+     p2c = Photo2Cartoon()
+
+     func = functools.partial(run, p2c=p2c)
+     func = functools.update_wrapper(func, run)
+
+
+     gr.Interface(
+         func,
+         [
+             gr.inputs.Image(type='file', label='Input Image'),
+         ],
+         [
+             gr.outputs.Image(
+                 type='pil',
+                 label='Result'),
+         ],
+         #examples=examples,
+         theme=args.theme,
+         title=TITLE,
+         description=DESCRIPTION,
+         article=ARTICLE,
+         allow_screenshot=args.allow_screenshot,
+         allow_flagging=args.allow_flagging,
+         live=args.live,
+     ).launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
p2c/cog.yaml ADDED
@@ -0,0 +1,27 @@
+ predict: "predict.py:Predictor"
+ build:
+   python_version: "3.8"
+   system_packages:
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_packages:
+     - "cmake==3.21.1"
+     - "torch==1.8.0"
+     - "torchvision==0.9.0"
+     - "numpy==1.19.2"
+     - "ipython==7.21.0"
+     - "opencv-python==4.3.0.38"
+     - "face-alignment==1.3.4"
+     - "tensorflow-gpu==2.5.0"
+   pre_install:
+     - pip install dlib
+
+
+
+
+
+
+
+
+
+
p2c/data_process.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import cv2
+ import numpy as np
+ from tqdm import tqdm
+ import argparse
+
+ from utils import Preprocess
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', type=str, help='photo folder path')
+ parser.add_argument('--save_path', type=str, help='save folder path')
+
+ args = parser.parse_args()
+ os.makedirs(args.save_path, exist_ok=True)
+
+ pre = Preprocess()
+
+ for idx, img_name in enumerate(tqdm(os.listdir(args.data_path))):
+     img = cv2.cvtColor(cv2.imread(os.path.join(args.data_path, img_name)), cv2.COLOR_BGR2RGB)
+
+     # face alignment and segmentation
+     face_rgba = pre.process(img)
+     if face_rgba is not None:
+         # change background to white
+         face = face_rgba[:,:,:3].copy()
+         mask = face_rgba[:,:,3].copy()[:,:,np.newaxis]/255.
+         face_white_bg = (face*mask + (1-mask)*255).astype(np.uint8)
+
+         cv2.imwrite(os.path.join(args.save_path, str(idx).zfill(4)+'.png'), cv2.cvtColor(face_white_bg, cv2.COLOR_RGB2BGR))
p2c/dataset.py ADDED
@@ -0,0 +1,108 @@
+ import torch.utils.data as data
+
+ from PIL import Image
+
+ import os
+ import os.path
+
+
+ def has_file_allowed_extension(filename, extensions):
+     """Checks if a file is an allowed extension.
+
+     Args:
+         filename (string): path to a file
+
+     Returns:
+         bool: True if the filename ends with a known image extension
+     """
+     filename_lower = filename.lower()
+     return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+ def find_classes(dir):
+     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+     classes.sort()
+     class_to_idx = {classes[i]: i for i in range(len(classes))}
+     return classes, class_to_idx
+
+
+ def make_dataset(dir, extensions):
+     images = []
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in sorted(fnames):
+             if has_file_allowed_extension(fname, extensions):
+                 path = os.path.join(root, fname)
+                 item = (path, 0)
+                 images.append(item)
+
+     return images
+
+
+ class DatasetFolder(data.Dataset):
+     def __init__(self, root, loader, extensions, transform=None, target_transform=None):
+         # classes, class_to_idx = find_classes(root)
+         samples = make_dataset(root, extensions)
+         if len(samples) == 0:
+             raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
+                                "Supported extensions are: " + ",".join(extensions)))
+
+         self.root = root
+         self.loader = loader
+         self.extensions = extensions
+         self.samples = samples
+
+         self.transform = transform
+         self.target_transform = target_transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is class_index of the target class.
+         """
+         path, target = self.samples[index]
+         sample = self.loader(path)
+         if self.transform is not None:
+             sample = self.transform(sample)
+         if self.target_transform is not None:
+             target = self.target_transform(target)
+
+         return sample, target
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __repr__(self):
+         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+         fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+         fmt_str += '    Root Location: {}\n'.format(self.root)
+         tmp = '    Transforms (if any): '
+         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         tmp = '    Target Transforms (if any): '
+         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+         return fmt_str
+
+
+ IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+ def pil_loader(path):
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         img = Image.open(f)
+         return img.convert('RGB')
+
+
+ def default_loader(path):
+     return pil_loader(path)
+
+
+ class ImageFolder(DatasetFolder):
+     def __init__(self, root, transform=None, target_transform=None,
+                  loader=default_loader):
+         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+                                           transform=transform,
+                                           target_transform=target_transform)
+         self.imgs = self.samples
p2c/dataset/README.md ADDED
@@ -0,0 +1,20 @@
+ ```
+ ├── dataset
+     └── photo2cartoon
+         ├── trainA
+             ├── xxx.jpg
+             ├── yyy.png
+             └── ...
+         ├── trainB
+             ├── zzz.jpg
+             ├── www.png
+             └── ...
+         ├── testA
+             ├── aaa.jpg
+             ├── bbb.png
+             └── ...
+         └── testB
+             ├── ccc.jpg
+             ├── ddd.png
+             └── ...
+ ```
p2c/images/QRcode.jpg ADDED
p2c/images/data_process.jpg ADDED
p2c/images/photo_test.jpg ADDED
p2c/images/results.png ADDED
p2c/images/title.png ADDED
p2c/models/UGATIT_sadalin_hourglass.py ADDED
@@ -0,0 +1,489 @@
1
+ import time
2
+ import itertools
3
+ from dataset import ImageFolder
4
+ from torchvision import transforms
5
+ from torch.utils.data import DataLoader
6
+ from .networks import *
7
+ from utils import *
8
+ from glob import glob
9
+ from .face_features import FaceFeatures
10
+
11
+
12
+ class UgatitSadalinHourglass(object):
13
+ def __init__(self, args):
14
+ self.light = args.light
15
+
16
+ if self.light:
17
+ self.model_name = 'UGATIT_light'
18
+ else:
19
+ self.model_name = 'UGATIT'
20
+
21
+ self.result_dir = args.result_dir
22
+ self.dataset = args.dataset
23
+
24
+ self.iteration = args.iteration
25
+ self.decay_flag = args.decay_flag
26
+
27
+ self.batch_size = args.batch_size
28
+ self.print_freq = args.print_freq
29
+ self.save_freq = args.save_freq
30
+
31
+ self.lr = args.lr
32
+ self.ch = args.ch
33
+
34
+ """ Weight """
35
+ self.adv_weight = args.adv_weight
36
+ self.cycle_weight = args.cycle_weight
37
+ self.identity_weight = args.identity_weight
38
+ self.cam_weight = args.cam_weight
39
+ self.faceid_weight = args.faceid_weight
40
+
41
+ """ Discriminator """
42
+ self.n_dis = args.n_dis
43
+
44
+ self.img_size = args.img_size
45
+ self.img_ch = args.img_ch
46
+
47
+ self.device = f'cuda:{args.gpu_ids[0]}'
48
+ self.gpu_ids = args.gpu_ids
49
+ self.benchmark_flag = args.benchmark_flag
50
+ self.resume = args.resume
51
+ self.rho_clipper = args.rho_clipper
52
+ self.w_clipper = args.w_clipper
53
+ self.pretrained_weights = args.pretrained_weights
54
+
55
+ if torch.backends.cudnn.enabled and self.benchmark_flag:
56
+ print('set benchmark !')
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ print("##### Information #####")
60
+ print("# light : ", self.light)
61
+ print("# dataset : ", self.dataset)
62
+ print("# batch_size : ", self.batch_size)
63
+ print("# iteration per epoch : ", self.iteration)
64
+
65
+ print("##### Discriminator #####")
66
+ print("# discriminator layer : ", self.n_dis)
67
+
68
+ print()
69
+
70
+ print("##### Weight #####")
71
+ print("# adv_weight : ", self.adv_weight)
72
+ print("# cycle_weight : ", self.cycle_weight)
73
+ print("# faceid_weight : ", self.faceid_weight)
74
+ print("# identity_weight : ", self.identity_weight)
75
+ print("# cam_weight : ", self.cam_weight)
76
+ print("# rho_clipper: ", self.rho_clipper)
77
+ print("# w_clipper: ", self.w_clipper)
78
+
79
+ ##################################################################################
80
+ # Model
81
+ ##################################################################################
82
+
83
+ def build_model(self):
84
+ """ DataLoader """
85
+ train_transform = transforms.Compose([
86
+ transforms.RandomHorizontalFlip(),
87
+ transforms.Resize((self.img_size + 30, self.img_size+30)),
88
+ transforms.RandomCrop(self.img_size),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
91
+ ])
92
+ test_transform = transforms.Compose([
93
+ transforms.Resize((self.img_size, self.img_size)),
94
+ transforms.ToTensor(),
95
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
96
+ ])
97
+ self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
98
+ self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
99
+ self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
100
+ self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
101
+
102
+ self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
103
+ self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
104
+ self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
105
+ self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)
106
+
107
+ """ Define Generator, Discriminator """
108
+ self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
109
+ self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
110
+ self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
111
+ self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
112
+ self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
113
+ self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
114
+
115
+ self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device)
116
+
117
+ """ Define Loss """
118
+ self.L1_loss = nn.L1Loss().to(self.device)
119
+ self.MSE_loss = nn.MSELoss().to(self.device)
120
+ self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)
121
+
122
+ """ Trainer """
123
+ self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001)
124
+ self.D_optim = torch.optim.Adam(
125
+ itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()),
126
+ lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001
127
+ )
128
+
129
+ """ Define Rho clipper to constraint the value of rho in AdaLIN and LIN"""
130
+ self.Rho_clipper = RhoClipper(0, self.rho_clipper)
131
+ self.W_Clipper = WClipper(0, self.w_clipper)
132
+
133
+ def train(self):
134
+ self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
135
+
136
+ start_iter = 1
137
+ if self.resume:
138
+ model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
139
+ if not len(model_list) == 0:
140
+ model_list.sort()
141
+ start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
142
+ self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
143
+ print(" [*] Load SUCCESS")
144
+ if self.decay_flag and start_iter > (self.iteration // 2):
145
+ self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
146
+ self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
147
+
148
+ if self.pretrained_weights:
149
+ params = torch.load(self.pretrained_weights, map_location=self.device)
150
+ self.genA2B.load_state_dict(params['genA2B'])
151
+ self.genB2A.load_state_dict(params['genB2A'])
152
+ self.disGA.load_state_dict(params['disGA'])
153
+ self.disGB.load_state_dict(params['disGB'])
154
+ self.disLA.load_state_dict(params['disLA'])
155
+ self.disLB.load_state_dict(params['disLB'])
156
+ print(" [*] Load {} Success".format(self.pretrained_weights))
157
+
158
+ if len(self.gpu_ids) > 1:
159
+ self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids)
160
+ self.genB2A = nn.DataParallel(self.genB2A, device_ids=self.gpu_ids)
161
+ self.disGA = nn.DataParallel(self.disGA, device_ids=self.gpu_ids)
162
+ self.disGB = nn.DataParallel(self.disGB, device_ids=self.gpu_ids)
163
+ self.disLA = nn.DataParallel(self.disLA, device_ids=self.gpu_ids)
164
+ self.disLB = nn.DataParallel(self.disLB, device_ids=self.gpu_ids)
165
+
166
+ # training loop
167
+ print('training start !')
168
+ start_time = time.time()
169
+ for step in range(start_iter, self.iteration + 1):
170
+ if self.decay_flag and step > (self.iteration // 2):
171
+ self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
172
+ self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
173
+
174
+ try:
175
+ real_A, _ = trainA_iter.next()
176
+ except:
177
+ trainA_iter = iter(self.trainA_loader)
178
+ real_A, _ = trainA_iter.next()
179
+
180
+ try:
181
+ real_B, _ = trainB_iter.next()
182
+ except:
183
+ trainB_iter = iter(self.trainB_loader)
184
+ real_B, _ = trainB_iter.next()
185
+
186
+ real_A, real_B = real_A.to(self.device), real_B.to(self.device)
187
+
188
+ # Update D
189
+ self.D_optim.zero_grad()
190
+
191
+ fake_A2B, _, _ = self.genA2B(real_A)
192
+ fake_B2A, _, _ = self.genB2A(real_B)
193
+
194
+ real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
195
+ real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
196
+ real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
197
+ real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)
198
+
199
+ fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
200
+ fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
201
+ fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
202
+ fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
203
+
204
+ D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + \
205
+ self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
206
+
207
+ D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + \
208
+ self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
209
+
210
+ D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + \
211
+ self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))
212
+
213
+ D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) +\
214
+ self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))
215
+
216
+ D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + \
217
+ self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))
218
+
219
+ D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + \
220
+ self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))
221
+
222
+ D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + \
223
+ self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))
224
+
225
+ D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) +\
226
+ self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))
227
+
228
+ D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
229
+ D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)
230
+
231
+ Discriminator_loss = D_loss_A + D_loss_B
232
+ Discriminator_loss.backward()
233
+ self.D_optim.step()
234
+
235
+ # Update G
236
+ self.G_optim.zero_grad()
237
+
238
+ fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
239
+ fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
240
+
241
+ fake_A2B2A, _, _ = self.genB2A(fake_A2B)
242
+ fake_B2A2B, _, _ = self.genA2B(fake_B2A)
243
+
244
+ fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
245
+ fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)
246
+
247
+ fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
248
+ fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
249
+ fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
250
+ fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
251
+
252
+ G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
253
+ G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
254
+ G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
255
+ G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
256
+ G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
257
+ G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
258
+ G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
259
+ G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))
260
+
261
+ G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
262
+ G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)
263
+
264
+ G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
265
+ G_identity_loss_B = self.L1_loss(fake_B2B, real_B)
266
+
267
+ G_id_loss_A = self.facenet.cosine_distance(real_A, fake_A2B)
268
+ G_id_loss_B = self.facenet.cosine_distance(real_B, fake_B2A)
269
+ if len(self.gpu_ids) > 1:
270
+ G_id_loss_A = torch.mean(G_id_loss_A)
271
+ G_id_loss_B = torch.mean(G_id_loss_B)
272
+
273
+ G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + \
274
+ self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
275
+ G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + \
276
+ self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))
277
+
278
+ G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + \
279
+ self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + \
280
+ self.cam_weight * G_cam_loss_A + self.faceid_weight * G_id_loss_A
281
+ G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + \
282
+ self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + \
283
+ self.cam_weight * G_cam_loss_B + self.faceid_weight * G_id_loss_B
284
+
285
+ Generator_loss = G_loss_A + G_loss_B
286
+ Generator_loss.backward()
287
+ self.G_optim.step()
288
+
289
+ # clip parameter of Soft-AdaLIN and LIN, applied after optimizer step
290
+ self.genA2B.apply(self.Rho_clipper)
291
+ self.genB2A.apply(self.Rho_clipper)
292
+
293
+ self.genA2B.apply(self.W_Clipper)
294
+ self.genB2A.apply(self.W_Clipper)
295
+
296
+ if step % 10 == 0:
297
+ print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time, Discriminator_loss, Generator_loss))
298
+ if step % self.print_freq == 0:
299
+ train_sample_num = 5
300
+ test_sample_num = 5
301
+ A2B = np.zeros((self.img_size * 7, 0, 3))
302
+ B2A = np.zeros((self.img_size * 7, 0, 3))
303
+
304
+ self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
305
+ with torch.no_grad():
306
+ for _ in range(train_sample_num):
307
+ try:
308
+ real_A, _ = trainA_iter.next()
309
+ except:
310
+ trainA_iter = iter(self.trainA_loader)
311
+ real_A, _ = trainA_iter.next()
312
+
313
+ try:
314
+ real_B, _ = trainB_iter.next()
315
+ except:
316
+ trainB_iter = iter(self.trainB_loader)
317
+ real_B, _ = trainB_iter.next()
318
+ real_A, real_B = real_A.to(self.device), real_B.to(self.device)
319
+
320
+ fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
321
+ fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
322
+
323
+ fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
324
+ fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
325
+
326
+ fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
327
+ fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
328
+
329
+ A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
330
+ cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
331
+ RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
332
+ cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
333
+ RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
334
+ cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
335
+ RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
336
+
337
+ B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
338
+ cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
339
+ RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
340
+ cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
341
+ RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
342
+ cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
343
+ RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)
344
+
345
+ for _ in range(test_sample_num):
346
+ try:
347
+ real_A, _ = testA_iter.next()
348
+ except:
349
+ testA_iter = iter(self.testA_loader)
350
+ real_A, _ = testA_iter.next()
351
+
352
+ try:
353
+ real_B, _ = testB_iter.next()
354
+ except:
355
+ testB_iter = iter(self.testB_loader)
356
+ real_B, _ = testB_iter.next()
357
+ real_A, real_B = real_A.to(self.device), real_B.to(self.device)
358
+
359
+ fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
360
+ fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
361
+
362
+ fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
363
+ fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
364
+
365
+ fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
366
+ fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
367
+
368
+ A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
369
+ cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
370
+ RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
371
+ cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
372
+ RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
373
+ cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
374
+ RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
375
+
376
+ B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
377
+ cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
378
+ RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
379
+ cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
380
+ RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
381
+ cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
382
+ RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)
383
+
384
+ cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0)
385
+ cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0)
386
+ self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()
387
+
388
+ if step % self.save_freq == 0:
389
+ self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)
390
+
391
+ if step % 1000 == 0:
392
+ params = {}
393
+
394
+ if len(self.gpu_ids) > 1:
395
+ params['genA2B'] = self.genA2B.module.state_dict()
396
+ params['genB2A'] = self.genB2A.module.state_dict()
397
+ params['disGA'] = self.disGA.module.state_dict()
398
+ params['disGB'] = self.disGB.module.state_dict()
399
+ params['disLA'] = self.disLA.module.state_dict()
400
+ params['disLB'] = self.disLB.module.state_dict()
401
+
402
+ else:
403
+ params['genA2B'] = self.genA2B.state_dict()
404
+ params['genB2A'] = self.genB2A.state_dict()
405
+ params['disGA'] = self.disGA.state_dict()
406
+ params['disGB'] = self.disGB.state_dict()
407
+ params['disLA'] = self.disLA.state_dict()
408
+ params['disLB'] = self.disLB.state_dict()
409
+ torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))
410
+
411
+ def save(self, dir, step):
412
+ params = {}
413
+
414
+ if len(self.gpu_ids) > 1:
415
+ params['genA2B'] = self.genA2B.module.state_dict()
416
+ params['genB2A'] = self.genB2A.module.state_dict()
417
+ params['disGA'] = self.disGA.module.state_dict()
418
+ params['disGB'] = self.disGB.module.state_dict()
419
+ params['disLA'] = self.disLA.module.state_dict()
420
+ params['disLB'] = self.disLB.module.state_dict()
421
+
422
+ else:
423
+ params['genA2B'] = self.genA2B.state_dict()
424
+ params['genB2A'] = self.genB2A.state_dict()
425
+ params['disGA'] = self.disGA.state_dict()
426
+ params['disGB'] = self.disGB.state_dict()
427
+ params['disLA'] = self.disLA.state_dict()
428
+ params['disLB'] = self.disLB.state_dict()
429
+ torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
430
+
431
+ def load(self, dir, step):
432
+ params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
433
+ self.genA2B.load_state_dict(params['genA2B'])
434
+ self.genB2A.load_state_dict(params['genB2A'])
435
+ self.disGA.load_state_dict(params['disGA'])
436
+ self.disGB.load_state_dict(params['disGB'])
437
+ self.disLA.load_state_dict(params['disLA'])
438
+ self.disLB.load_state_dict(params['disLB'])
439
+
440
+ def test(self):
441
+ model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
442
+ if not len(model_list) == 0:
443
+ model_list.sort()
444
+ iter = int(model_list[-1].split('_')[-1].split('.')[0])
445
+ self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter)
446
+ print(" [*] Load SUCCESS")
447
+ else:
448
+ print(" [*] Load FAILURE")
449
+ return
450
+
451
+ self.genA2B.eval(), self.genB2A.eval()
452
+ with torch.no_grad():
453
+ for n, (real_A, _) in enumerate(self.testA_loader):
454
+ real_A = real_A.to(self.device)
455
+
456
+ fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
457
+
458
+ fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
459
+
460
+ fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
461
+
462
+ A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
463
+ cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
464
+ RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
465
+ cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
466
+ RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
467
+ cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
468
+ RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)
469
+
470
+ cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0)
471
+
472
+ for n, (real_B, _) in enumerate(self.testB_loader):
473
+ real_B = real_B.to(self.device)
474
+
475
+ fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)
476
+
477
+ fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)
478
+
479
+ fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)
480
+
481
+ B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
482
+ cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
483
+ RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
484
+ cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
485
+ RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
486
+ cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
487
+ RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)
488
+
489
+ cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0)
p2c/models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .networks import ResnetGenerator
+ from .UGATIT_sadalin_hourglass import UgatitSadalinHourglass
+
p2c/models/face_features.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+ import torch.nn.functional as F
+ from .mobilefacenet import MobileFaceNet
+
+
+ class FaceFeatures(object):
+     def __init__(self, weights_path, device):
+         self.device = device
+         self.model = MobileFaceNet(512).to(device)
+         self.model.load_state_dict(torch.load(weights_path))
+         self.model.eval()
+
+     def infer(self, batch_tensor):
+         # crop face
+         h, w = batch_tensor.shape[2:]
+         top = int(h / 2.1 * (0.8 - 0.33))
+         bottom = int(h - (h / 2.1 * 0.3))
+         size = bottom - top
+         left = int(w / 2 - size / 2)
+         right = left + size
+         batch_tensor = batch_tensor[:, :, top: bottom, left: right]
+
+         batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)
+
+         features = self.model(batch_tensor)
+         return features
+
+     def cosine_distance(self, batch_tensor1, batch_tensor2):
+         feature1 = self.infer(batch_tensor1)
+         feature2 = self.infer(batch_tensor2)
+         return 1 - torch.cosine_similarity(feature1, feature2)
p2c/models/mobilefacenet.py ADDED
@@ -0,0 +1,258 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, \
2
+ MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
3
+ import torch
4
+ from collections import namedtuple
5
+
6
+
7
+ ################################## Original Arcface Model #############################################################
8
+
9
+ class Flatten(Module):
10
+ def forward(self, input):
11
+ return input.view(input.size(0), -1)
12
+
13
+
14
+ def l2_norm(input, axis=1):
15
+ norm = torch.norm(input, 2, axis, True)
16
+ output = torch.div(input, norm)
17
+ return output
18
+
19
+
20
+ class SEModule(Module):
21
+ def __init__(self, channels, reduction):
22
+ super(SEModule, self).__init__()
23
+ self.avg_pool = AdaptiveAvgPool2d(1)
24
+ self.fc1 = Conv2d(
25
+ channels, channels // reduction, kernel_size=1, padding=0, bias=False)
26
+ self.relu = ReLU(inplace=True)
27
+ self.fc2 = Conv2d(
28
+ channels // reduction, channels, kernel_size=1, padding=0, bias=False)
29
+ self.sigmoid = Sigmoid()
30
+
31
+ def forward(self, x):
32
+ module_input = x
33
+ x = self.avg_pool(x)
34
+ x = self.fc1(x)
35
+ x = self.relu(x)
36
+ x = self.fc2(x)
37
+ x = self.sigmoid(x)
38
+ return module_input * x
39
+
40
+
41
+ class bottleneck_IR(Module):
42
+ def __init__(self, in_channel, depth, stride):
43
+ super(bottleneck_IR, self).__init__()
44
+ if in_channel == depth:
45
+ self.shortcut_layer = MaxPool2d(1, stride)
46
+ else:
47
+ self.shortcut_layer = Sequential(
48
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
49
+ self.res_layer = Sequential(
50
+ BatchNorm2d(in_channel),
51
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
52
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
53
+
54
+ def forward(self, x):
55
+ shortcut = self.shortcut_layer(x)
56
+ res = self.res_layer(x)
57
+ return res + shortcut
58
+
59
+
60
+ class bottleneck_IR_SE(Module):
61
+ def __init__(self, in_channel, depth, stride):
62
+ super(bottleneck_IR_SE, self).__init__()
63
+ if in_channel == depth:
64
+ self.shortcut_layer = MaxPool2d(1, stride)
65
+ else:
66
+ self.shortcut_layer = Sequential(
67
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
68
+ BatchNorm2d(depth))
69
+ self.res_layer = Sequential(
70
+ BatchNorm2d(in_channel),
71
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
72
+ PReLU(depth),
73
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
74
+ BatchNorm2d(depth),
75
+ SEModule(depth, 16)
76
+ )
77
+
78
+ def forward(self, x):
79
+ shortcut = self.shortcut_layer(x)
80
+ res = self.res_layer(x)
81
+ return res + shortcut
82
+
83
+
84
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
85
+ '''A named tuple describing a ResNet block.'''
86
+
87
+
88
+ def get_block(in_channel, depth, num_units, stride=2):
89
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
90
+
91
+
92
+ def get_blocks(num_layers):
93
+ if num_layers == 50:
94
+ blocks = [
95
+ get_block(in_channel=64, depth=64, num_units=3),
96
+ get_block(in_channel=64, depth=128, num_units=4),
97
+ get_block(in_channel=128, depth=256, num_units=14),
98
+ get_block(in_channel=256, depth=512, num_units=3)
99
+ ]
100
+ elif num_layers == 100:
101
+ blocks = [
102
+ get_block(in_channel=64, depth=64, num_units=3),
103
+ get_block(in_channel=64, depth=128, num_units=13),
104
+ get_block(in_channel=128, depth=256, num_units=30),
105
+ get_block(in_channel=256, depth=512, num_units=3)
106
+ ]
107
+ elif num_layers == 152:
108
+ blocks = [
109
+ get_block(in_channel=64, depth=64, num_units=3),
110
+ get_block(in_channel=64, depth=128, num_units=8),
111
+ get_block(in_channel=128, depth=256, num_units=36),
112
+ get_block(in_channel=256, depth=512, num_units=3)
113
+ ]
114
+ return blocks
115
+
116
+
117
+ class Backbone(Module):
118
+ def __init__(self, num_layers, drop_ratio, mode='ir'):
119
+ super(Backbone, self).__init__()
120
+ assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
121
+ assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
122
+ blocks = get_blocks(num_layers)
123
+ if mode == 'ir':
124
+ unit_module = bottleneck_IR
125
+ elif mode == 'ir_se':
126
+ unit_module = bottleneck_IR_SE
127
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
128
+ BatchNorm2d(64),
129
+ PReLU(64))
130
+ self.output_layer = Sequential(BatchNorm2d(512),
131
+ Dropout(drop_ratio),
132
+ Flatten(),
133
+ Linear(512 * 7 * 7, 512),
134
+ BatchNorm1d(512))
135
+ modules = []
136
+ for block in blocks:
137
+ for bottleneck in block:
138
+ modules.append(
139
+ unit_module(bottleneck.in_channel,
140
+ bottleneck.depth,
141
+ bottleneck.stride))
142
+ self.body = Sequential(*modules)
143
+
144
+ def forward(self, x):
145
+ x = self.input_layer(x)
146
+ x = self.body(x)
147
+ x = self.output_layer(x)
148
+ return l2_norm(x)
149
+
150
+
151
+ ################################## MobileFaceNet #############################################################
152
+
153
+ class Conv_block(Module):
154
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
155
+ super(Conv_block, self).__init__()
156
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
157
+ bias=False)
158
+ self.bn = BatchNorm2d(out_c)
159
+ self.prelu = PReLU(out_c)
160
+
161
+ def forward(self, x):
162
+ x = self.conv(x)
163
+ x = self.bn(x)
164
+ x = self.prelu(x)
165
+ return x
166
+
167
+
168
+ class Linear_block(Module):
169
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
170
+ super(Linear_block, self).__init__()
171
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
172
+ bias=False)
173
+ self.bn = BatchNorm2d(out_c)
174
+
175
+ def forward(self, x):
176
+ x = self.conv(x)
177
+ x = self.bn(x)
178
+ return x
179
+
180
+
181
+ class Depth_Wise(Module):
182
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
183
+ super(Depth_Wise, self).__init__()
184
+ self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
185
+ self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
186
+ self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
187
+ self.residual = residual
188
+
189
+ def forward(self, x):
190
+ if self.residual:
191
+ short_cut = x
192
+ x = self.conv(x)
193
+ x = self.conv_dw(x)
194
+ x = self.project(x)
195
+ if self.residual:
196
+ output = short_cut + x
197
+ else:
198
+ output = x
199
+ return output
200
+
201
+
202
+ class Residual(Module):
203
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
204
+ super(Residual, self).__init__()
205
+ modules = []
206
+ for _ in range(num_block):
207
+ modules.append(
208
+ Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
209
+ self.model = Sequential(*modules)
210
+
211
+ def forward(self, x):
212
+ return self.model(x)
213
+
214
+
215
+ class MobileFaceNet(Module):
216
+ def __init__(self, embedding_size):
217
+ super(MobileFaceNet, self).__init__()
218
+ self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
219
+ self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
220
+ self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
221
+ self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
222
+ self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
223
+ self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
224
+ self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
225
+ self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
226
+ self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
227
+ self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
228
+ self.conv_6_flatten = Flatten()
229
+ self.linear = Linear(512, embedding_size, bias=False)
230
+ self.bn = BatchNorm1d(embedding_size)
231
+
232
+ def forward(self, x):
233
+ out = self.conv1(x)
234
+
235
+ out = self.conv2_dw(out)
236
+
237
+ out = self.conv_23(out)
238
+
239
+ out = self.conv_3(out)
240
+
241
+ out = self.conv_34(out)
242
+
243
+ out = self.conv_4(out)
244
+
245
+ out = self.conv_45(out)
246
+
247
+ out = self.conv_5(out)
248
+
249
+ out = self.conv_6_sep(out)
250
+
251
+ out = self.conv_6_dw(out)
252
+
253
+ out = self.conv_6_flatten(out)
254
+
255
+ out = self.linear(out)
256
+
257
+ out = self.bn(out)
258
+ return l2_norm(out)
p2c/models/model_mobilefacenet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f3bbd745247b32641724bf6d7964df7fd94ea5a098fe16d692b412fe44cd59b
+ size 4938364
p2c/models/networks.py ADDED
@@ -0,0 +1,485 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.parameter import Parameter
5
+
6
+
7
+ class ResnetGenerator(nn.Module):
8
+ def __init__(self, ngf=64, img_size=256, light=False):
9
+ super(ResnetGenerator, self).__init__()
10
+ self.light = light
11
+
12
+ self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
13
+ nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
14
+ nn.InstanceNorm2d(ngf),
15
+ nn.ReLU(True))
16
+
17
+ self.HourGlass1 = HourGlass(ngf, ngf)
18
+ self.HourGlass2 = HourGlass(ngf, ngf)
19
+
20
+ # Down-Sampling
21
+ self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
22
+ nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
23
+ nn.InstanceNorm2d(ngf * 2),
24
+ nn.ReLU(True))
25
+
26
+ self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
27
+ nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
28
+ nn.InstanceNorm2d(ngf*4),
29
+ nn.ReLU(True))
30
+
31
+ # Encoder Bottleneck
32
+ self.EncodeBlock1 = ResnetBlock(ngf*4)
33
+ self.EncodeBlock2 = ResnetBlock(ngf*4)
34
+ self.EncodeBlock3 = ResnetBlock(ngf*4)
35
+ self.EncodeBlock4 = ResnetBlock(ngf*4)
36
+
37
+ # Class Activation Map
38
+ self.gap_fc = nn.Linear(ngf*4, 1)
39
+ self.gmp_fc = nn.Linear(ngf*4, 1)
40
+ self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
41
+ self.relu = nn.ReLU(True)
42
+
43
+ # Gamma, Beta block
44
+ if self.light:
45
+ self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
46
+ nn.ReLU(True),
47
+ nn.Linear(ngf*4, ngf*4),
48
+ nn.ReLU(True))
49
+ else:
50
+ self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
51
+ nn.ReLU(True),
52
+ nn.Linear(ngf*4, ngf*4),
53
+ nn.ReLU(True))
54
+
55
+ # Decoder Bottleneck
56
+ self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
57
+ self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
58
+ self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
59
+ self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)
60
+
61
+ # Up-Sampling
62
+ self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
63
+ nn.ReflectionPad2d(1),
64
+ nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
65
+ LIN(ngf*2),
66
+ nn.ReLU(True))
67
+
68
+ self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
69
+ nn.ReflectionPad2d(1),
70
+ nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
71
+ LIN(ngf),
72
+ nn.ReLU(True))
73
+
74
+ self.HourGlass3 = HourGlass(ngf, ngf)
75
+ self.HourGlass4 = HourGlass(ngf, ngf, False)
76
+
77
+ self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
78
+ nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
79
+ nn.Tanh())
80
+
81
+ def forward(self, x):
82
+ x = self.ConvBlock1(x)
83
+ x = self.HourGlass1(x)
84
+ x = self.HourGlass2(x)
85
+
86
+ x = self.DownBlock1(x)
87
+ x = self.DownBlock2(x)
88
+
89
+ x = self.EncodeBlock1(x)
90
+ content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
91
+ x = self.EncodeBlock2(x)
92
+ content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
93
+ x = self.EncodeBlock3(x)
94
+ content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
95
+ x = self.EncodeBlock4(x)
96
+ content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
97
+
98
+ gap = F.adaptive_avg_pool2d(x, 1)
99
+ gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
100
+ gap_weight = list(self.gap_fc.parameters())[0]
101
+ gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
102
+
103
+ gmp = F.adaptive_max_pool2d(x, 1)
104
+ gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
105
+ gmp_weight = list(self.gmp_fc.parameters())[0]
106
+ gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
107
+
108
+ cam_logit = torch.cat([gap_logit, gmp_logit], 1)
109
+ x = torch.cat([gap, gmp], 1)
110
+ x = self.relu(self.conv1x1(x))
111
+
112
+ heatmap = torch.sum(x, dim=1, keepdim=True)
113
+
114
+ if self.light:
115
+ x_ = F.adaptive_avg_pool2d(x, 1)
116
+ style_features = self.FC(x_.view(x_.shape[0], -1))
117
+ else:
118
+ style_features = self.FC(x.view(x.shape[0], -1))
119
+
120
+ x = self.DecodeBlock1(x, content_features4, style_features)
121
+ x = self.DecodeBlock2(x, content_features3, style_features)
122
+ x = self.DecodeBlock3(x, content_features2, style_features)
123
+ x = self.DecodeBlock4(x, content_features1, style_features)
124
+
125
+ x = self.UpBlock1(x)
126
+ x = self.UpBlock2(x)
127
+
128
+ x = self.HourGlass3(x)
129
+ x = self.HourGlass4(x)
130
+ out = self.ConvBlock2(x)
131
+
132
+ return out, cam_logit, heatmap
133
+
134
+
135
+ class ConvBlock(nn.Module):
136
+ def __init__(self, dim_in, dim_out):
137
+ super(ConvBlock, self).__init__()
138
+ self.dim_out = dim_out
139
+
140
+ self.ConvBlock1 = nn.Sequential(nn.InstanceNorm2d(dim_in),
141
+ nn.ReLU(True),
142
+ nn.ReflectionPad2d(1),
143
+ nn.Conv2d(dim_in, dim_out//2, kernel_size=3, stride=1, bias=False))
144
+
145
+ self.ConvBlock2 = nn.Sequential(nn.InstanceNorm2d(dim_out//2),
146
+ nn.ReLU(True),
147
+ nn.ReflectionPad2d(1),
148
+ nn.Conv2d(dim_out//2, dim_out//4, kernel_size=3, stride=1, bias=False))
149
+
150
+ self.ConvBlock3 = nn.Sequential(nn.InstanceNorm2d(dim_out//4),
151
+ nn.ReLU(True),
152
+ nn.ReflectionPad2d(1),
153
+ nn.Conv2d(dim_out//4, dim_out//4, kernel_size=3, stride=1, bias=False))
154
+
155
+ self.ConvBlock4 = nn.Sequential(nn.InstanceNorm2d(dim_in),
156
+ nn.ReLU(True),
157
+ nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False))
158
+
159
+ def forward(self, x):
160
+ residual = x
161
+
162
+ x1 = self.ConvBlock1(x)
163
+ x2 = self.ConvBlock2(x1)
164
+ x3 = self.ConvBlock3(x2)
165
+ out = torch.cat((x1, x2, x3), 1)
166
+
167
+ if residual.size(1) != self.dim_out:
168
+ residual = self.ConvBlock4(residual)
169
+
170
+ return residual + out
171
+
172
+
173
+ class HourGlass(nn.Module):
174
+ def __init__(self, dim_in, dim_out, use_res=True):
175
+ super(HourGlass, self).__init__()
176
+ self.use_res = use_res
177
+
178
+ self.HG = nn.Sequential(HourGlassBlock(dim_in, dim_out),
179
+ ConvBlock(dim_out, dim_out),
180
+ nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1, bias=False),
181
+ nn.InstanceNorm2d(dim_out),
182
+ nn.ReLU(True))
183
+
184
+ self.Conv1 = nn.Conv2d(dim_out, 3, kernel_size=1, stride=1)
185
+
186
+ if self.use_res:
187
+ self.Conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1)
188
+ self.Conv3 = nn.Conv2d(3, dim_out, kernel_size=1, stride=1)
189
+
190
+ def forward(self, x):
191
+ ll = self.HG(x)
192
+ tmp_out = self.Conv1(ll)
193
+
194
+ if self.use_res:
195
+ ll = self.Conv2(ll)
196
+ tmp_out_ = self.Conv3(tmp_out)
197
+ return x + ll + tmp_out_
198
+
199
+ else:
200
+ return tmp_out
201
+
202
+
203
+ class HourGlassBlock(nn.Module):
204
+ def __init__(self, dim_in, dim_out):
205
+ super(HourGlassBlock, self).__init__()
206
+
207
+ self.ConvBlock1_1 = ConvBlock(dim_in, dim_out)
208
+ self.ConvBlock1_2 = ConvBlock(dim_out, dim_out)
209
+ self.ConvBlock2_1 = ConvBlock(dim_out, dim_out)
210
+ self.ConvBlock2_2 = ConvBlock(dim_out, dim_out)
211
+ self.ConvBlock3_1 = ConvBlock(dim_out, dim_out)
212
+ self.ConvBlock3_2 = ConvBlock(dim_out, dim_out)
213
+ self.ConvBlock4_1 = ConvBlock(dim_out, dim_out)
214
+ self.ConvBlock4_2 = ConvBlock(dim_out, dim_out)
215
+
216
+ self.ConvBlock5 = ConvBlock(dim_out, dim_out)
217
+
218
+ self.ConvBlock6 = ConvBlock(dim_out, dim_out)
219
+ self.ConvBlock7 = ConvBlock(dim_out, dim_out)
220
+ self.ConvBlock8 = ConvBlock(dim_out, dim_out)
221
+ self.ConvBlock9 = ConvBlock(dim_out, dim_out)
222
+
223
+ def forward(self, x):
224
+ skip1 = self.ConvBlock1_1(x)
225
+ down1 = F.avg_pool2d(x, 2)
226
+ down1 = self.ConvBlock1_2(down1)
227
+
228
+ skip2 = self.ConvBlock2_1(down1)
229
+ down2 = F.avg_pool2d(down1, 2)
230
+ down2 = self.ConvBlock2_2(down2)
231
+
232
+ skip3 = self.ConvBlock3_1(down2)
233
+ down3 = F.avg_pool2d(down2, 2)
234
+ down3 = self.ConvBlock3_2(down3)
235
+
236
+ skip4 = self.ConvBlock4_1(down3)
237
+ down4 = F.avg_pool2d(down3, 2)
238
+ down4 = self.ConvBlock4_2(down4)
239
+
240
+ center = self.ConvBlock5(down4)
241
+
242
+ up4 = self.ConvBlock6(center)
243
+ up4 = F.upsample(up4, scale_factor=2)
244
+ up4 = skip4 + up4
245
+
246
+ up3 = self.ConvBlock7(up4)
247
+ up3 = F.upsample(up3, scale_factor=2)
248
+ up3 = skip3 + up3
249
+
250
+ up2 = self.ConvBlock8(up3)
251
+ up2 = F.upsample(up2, scale_factor=2)
252
+ up2 = skip2 + up2
253
+
254
+ up1 = self.ConvBlock9(up2)
255
+ up1 = F.upsample(up1, scale_factor=2)
256
+ up1 = skip1 + up1
257
+
258
+ return up1
259
+
260
+
261
+ class ResnetBlock(nn.Module):
262
+ def __init__(self, dim, use_bias=False):
263
+ super(ResnetBlock, self).__init__()
264
+ conv_block = []
265
+ conv_block += [nn.ReflectionPad2d(1),
266
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
267
+ nn.InstanceNorm2d(dim),
268
+ nn.ReLU(True)]
269
+
270
+ conv_block += [nn.ReflectionPad2d(1),
271
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
272
+ nn.InstanceNorm2d(dim)]
273
+
274
+ self.conv_block = nn.Sequential(*conv_block)
275
+
276
+ def forward(self, x):
277
+ out = x + self.conv_block(x)
278
+ return out
279
+
280
+
281
+ class ResnetSoftAdaLINBlock(nn.Module):
282
+ def __init__(self, dim, use_bias=False):
283
+ super(ResnetSoftAdaLINBlock, self).__init__()
284
+ self.pad1 = nn.ReflectionPad2d(1)
285
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
286
+ self.norm1 = SoftAdaLIN(dim)
287
+ self.relu1 = nn.ReLU(True)
288
+
289
+ self.pad2 = nn.ReflectionPad2d(1)
290
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
291
+ self.norm2 = SoftAdaLIN(dim)
292
+
293
+ def forward(self, x, content_features, style_features):
294
+ out = self.pad1(x)
295
+ out = self.conv1(out)
296
+ out = self.norm1(out, content_features, style_features)
297
+ out = self.relu1(out)
298
+
299
+ out = self.pad2(out)
300
+ out = self.conv2(out)
301
+ out = self.norm2(out, content_features, style_features)
302
+ return out + x
303
+
304
+
305
+ class ResnetAdaLINBlock(nn.Module):
306
+ def __init__(self, dim, use_bias=False):
307
+ super(ResnetAdaLINBlock, self).__init__()
308
+ self.pad1 = nn.ReflectionPad2d(1)
309
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
310
+ self.norm1 = adaLIN(dim)
311
+ self.relu1 = nn.ReLU(True)
312
+
313
+ self.pad2 = nn.ReflectionPad2d(1)
314
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
315
+ self.norm2 = adaLIN(dim)
316
+
317
+ def forward(self, x, gamma, beta):
318
+ out = self.pad1(x)
319
+ out = self.conv1(out)
320
+ out = self.norm1(out, gamma, beta)
321
+ out = self.relu1(out)
322
+ out = self.pad2(out)
323
+ out = self.conv2(out)
324
+ out = self.norm2(out, gamma, beta)
325
+
326
+ return out + x
327
+
328
+
329
+ class SoftAdaLIN(nn.Module):
330
+ def __init__(self, num_features, eps=1e-5):
331
+ super(SoftAdaLIN, self).__init__()
332
+ self.norm = adaLIN(num_features, eps)
333
+
334
+ self.w_gamma = Parameter(torch.zeros(1, num_features))
335
+ self.w_beta = Parameter(torch.zeros(1, num_features))
336
+
337
+ self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features),
338
+ nn.ReLU(True),
339
+ nn.Linear(num_features, num_features))
340
+ self.c_beta = nn.Sequential(nn.Linear(num_features, num_features),
341
+ nn.ReLU(True),
342
+ nn.Linear(num_features, num_features))
343
+ self.s_gamma = nn.Linear(num_features, num_features)
344
+ self.s_beta = nn.Linear(num_features, num_features)
345
+
346
+ def forward(self, x, content_features, style_features):
347
+ content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
348
+ style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)
349
+
350
+ w_gamma, w_beta = self.w_gamma.expand(x.shape[0], -1), self.w_beta.expand(x.shape[0], -1)
351
+ soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
352
+ soft_beta = (1. - w_beta) * style_beta + w_beta * content_beta
353
+
354
+ out = self.norm(x, soft_gamma, soft_beta)
355
+ return out
356
+
357
+
358
+ class adaLIN(nn.Module):
359
+ def __init__(self, num_features, eps=1e-5):
360
+ super(adaLIN, self).__init__()
361
+ self.eps = eps
362
+ self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
363
+ self.rho.data.fill_(0.9)
364
+
365
+ def forward(self, input, gamma, beta):
366
+ in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
367
+ out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
368
+ ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
369
+ out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
370
+ out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
371
+ out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
372
+
373
+ return out
374
+
375
+
376
+ class LIN(nn.Module):
377
+ def __init__(self, num_features, eps=1e-5):
378
+ super(LIN, self).__init__()
379
+ self.eps = eps
380
+ self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
381
+ self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
382
+ self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
383
+ self.rho.data.fill_(0.0)
384
+ self.gamma.data.fill_(1.0)
385
+ self.beta.data.fill_(0.0)
386
+
387
+ def forward(self, input):
388
+ in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
389
+ out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
390
+ ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
391
+ out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
392
+ out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
393
+ out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
394
+
395
+ return out
396
+
397
+
398
+ class Discriminator(nn.Module):
399
+ def __init__(self, input_nc, ndf=64, n_layers=5):
400
+ super(Discriminator, self).__init__()
401
+ model = [nn.ReflectionPad2d(1),
402
+ nn.utils.spectral_norm(
403
+ nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
404
+ nn.LeakyReLU(0.2, True)]
405
+
406
+ for i in range(1, n_layers - 2):
407
+ mult = 2 ** (i - 1)
408
+ model += [nn.ReflectionPad2d(1),
409
+ nn.utils.spectral_norm(
410
+ nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
411
+ nn.LeakyReLU(0.2, True)]
412
+
413
+ mult = 2 ** (n_layers - 2 - 1)
414
+ model += [nn.ReflectionPad2d(1),
415
+ nn.utils.spectral_norm(
416
+ nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
417
+ nn.LeakyReLU(0.2, True)]
418
+
419
+ # Class Activation Map
420
+ mult = 2 ** (n_layers - 2)
421
+ self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
422
+ self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
423
+ self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
424
+ self.leaky_relu = nn.LeakyReLU(0.2, True)
425
+
426
+ self.pad = nn.ReflectionPad2d(1)
427
+ self.conv = nn.utils.spectral_norm(
428
+ nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
429
+
430
+ self.model = nn.Sequential(*model)
431
+
432
+ def forward(self, input):
433
+ x = self.model(input)
434
+
435
+ gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
436
+ gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
437
+ gap_weight = list(self.gap_fc.parameters())[0]
438
+ gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
439
+
440
+ gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
441
+ gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
442
+ gmp_weight = list(self.gmp_fc.parameters())[0]
443
+ gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
444
+
445
+ cam_logit = torch.cat([gap_logit, gmp_logit], 1)
446
+ x = torch.cat([gap, gmp], 1)
447
+ x = self.leaky_relu(self.conv1x1(x))
448
+
449
+ heatmap = torch.sum(x, dim=1, keepdim=True)
450
+
451
+ x = self.pad(x)
452
+ out = self.conv(x)
453
+
454
+ return out, cam_logit, heatmap
455
+
456
+
457
+ class RhoClipper(object):
458
+ def __init__(self, min, max):
459
+ self.clip_min = min
460
+ self.clip_max = max
461
+ assert min < max
462
+
463
+ def __call__(self, module):
464
+ if hasattr(module, 'rho'):
465
+ w = module.rho.data
466
+ w = w.clamp(self.clip_min, self.clip_max)
467
+ module.rho.data = w
468
+
469
+
470
+ class WClipper(object):
471
+ def __init__(self, min, max):
472
+ self.clip_min = min
473
+ self.clip_max = max
474
+ assert min < max
475
+
476
+ def __call__(self, module):
477
+ if hasattr(module, 'w_gamma'):
478
+ w = module.w_gamma.data
479
+ w = w.clamp(self.clip_min, self.clip_max)
480
+ module.w_gamma.data = w
481
+
482
+ if hasattr(module, 'w_beta'):
483
+ w = module.w_beta.data
484
+ w = w.clamp(self.clip_min, self.clip_max)
485
+ module.w_beta.data = w
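Aside on the normalization blocks above: adaLIN blends per-channel instance-norm statistics with per-sample layer-norm statistics through the learned rho, then rescales with externally predicted gamma/beta (SoftAdaLIN predicts those from content and style features). A minimal standalone sketch of that blend, with invented tensor sizes, not taken from the repo:

import torch

x = torch.randn(2, 8, 16, 16)                       # (N, C, H, W) feature map
rho = torch.full((1, 8, 1, 1), 0.9)                 # same init value as adaLIN above
gamma, beta = torch.ones(2, 8), torch.zeros(2, 8)   # per-sample scale/shift, e.g. from SoftAdaLIN
eps = 1e-5

out_in = (x - x.mean(dim=[2, 3], keepdim=True)) / torch.sqrt(x.var(dim=[2, 3], keepdim=True) + eps)
out_ln = (x - x.mean(dim=[1, 2, 3], keepdim=True)) / torch.sqrt(x.var(dim=[1, 2, 3], keepdim=True) + eps)
out = rho * out_in + (1 - rho) * out_ln              # learned mix of instance norm and layer norm
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
print(out.shape)                                     # torch.Size([2, 8, 16, 16])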
p2c/models/photo2cartoon_weights.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:542914cb8580cb733c7e914d22cc24ddabbbb207516d74ffc793f2a1b6c3eeb3
3
+ size 15290506
p2c/models/photo2cartoon_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e08c84ea4c62251c6157dbf1d3ef44d2549d6aa8c9ee72ec9e4b3089ce5d5f0f
3
+ size 144306956
p2c/predict.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cog
2
+ import cv2
3
+ import tempfile
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ from pathlib import Path
8
+ from utils import Preprocess
9
+ from models import ResnetGenerator
10
+
11
+
12
+ class Predictor(cog.Predictor):
13
+ def setup(self):
14
+ pass
15
+
16
+ @cog.input("photo", type=Path, help="portrait photo (size < 1M)")
17
+ def predict(self, photo):
18
+ img = cv2.cvtColor(cv2.imread(str(photo)), cv2.COLOR_BGR2RGB)
19
+ out_path = gen_cartoon(img)
20
+ return out_path
21
+
22
+
23
+ def gen_cartoon(img):
24
+ pre = Preprocess()
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+ net = ResnetGenerator(ngf=32, img_size=256, light=True).to(device)
27
+
28
+ assert os.path.exists(
29
+ './models/photo2cartoon_weights.pt'), "[Step1: load weights] Can not find 'photo2cartoon_weights.pt' in folder 'models!!!'"
30
+ params = torch.load('./models/photo2cartoon_weights.pt', map_location=device)
31
+ net.load_state_dict(params['genA2B'])
32
+
33
+ # face alignment and segmentation
34
+ face_rgba = pre.process(img)
35
+ if face_rgba is None:
36
+ return None
37
+
38
+ face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
39
+ face = face_rgba[:, :, :3].copy()
40
+ mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
41
+ face = (face * mask + (1 - mask) * 255) / 127.5 - 1
42
+
43
+ face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
44
+ face = torch.from_numpy(face).to(device)
45
+
46
+ # inference
47
+ with torch.no_grad():
48
+ cartoon = net(face)[0][0]
49
+
50
+ # post-process
51
+ cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
52
+ cartoon = (cartoon + 1) * 127.5
53
+ cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
54
+ cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
55
+ out_path = Path(tempfile.mkdtemp()) / "out.png"
56
+ cv2.imwrite(str(out_path), cartoon)
57
+ return out_path
p2c/test.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from models import ResnetGenerator
6
+ import argparse
7
+ from utils import Preprocess
8
+
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--photo_path', type=str, help='input photo path')
12
+ parser.add_argument('--save_path', type=str, help='cartoon save path')
13
+ args = parser.parse_args()
14
+
15
+ os.makedirs(os.path.dirname(args.save_path) or '.', exist_ok=True)  # dirname is '' when save_path is a bare filename
16
+
17
+ class Photo2Cartoon:
18
+ def __init__(self):
19
+ self.pre = Preprocess()
20
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+ self.net = ResnetGenerator(ngf=32, img_size=256, light=True).to(self.device)
22
+
23
+ assert os.path.exists('./models/photo2cartoon_weights.pt'), "[Step1: load weights] Can not find 'photo2cartoon_weights.pt' in folder 'models!!!'"
24
+ params = torch.load('./models/photo2cartoon_weights.pt', map_location=self.device)
25
+ self.net.load_state_dict(params['genA2B'])
26
+ print('[Step1: load weights] success!')
27
+
28
+ def inference(self, img):
29
+ # face alignment and segmentation
30
+ face_rgba = self.pre.process(img)
31
+ if face_rgba is None:
32
+ print('[Step2: face detect] cannot detect a face!')
33
+ return None
34
+
35
+ print('[Step2: face detect] success!')
36
+ face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
37
+ face = face_rgba[:, :, :3].copy()
38
+ mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
39
+ face = (face*mask + (1-mask)*255) / 127.5 - 1
40
+
41
+ face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
42
+ face = torch.from_numpy(face).to(self.device)
43
+
44
+ # inference
45
+ with torch.no_grad():
46
+ cartoon = self.net(face)[0][0]
47
+
48
+ # post-process
49
+ cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
50
+ cartoon = (cartoon + 1) * 127.5
51
+ cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
52
+ cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
53
+ print('[Step3: photo to cartoon] success!')
54
+ return cartoon
55
+
56
+
57
+ if __name__ == '__main__':
58
+ img = cv2.cvtColor(cv2.imread(args.photo_path), cv2.COLOR_BGR2RGB)
59
+ c2p = Photo2Cartoon()
60
+ cartoon = c2p.inference(img)
61
+ if cartoon is not None:
62
+ cv2.imwrite(args.save_path, cartoon)
63
+ print('Cartoon portrait has been saved successfully!')
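The .pt weights referenced by the assert are included in this commit under p2c/models, so the script above can be run from the p2c directory roughly as follows (image paths are placeholders):

python test.py --photo_path ./images/photo_test.jpg --save_path ./images/cartoon_result.png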
p2c/test_onnx.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+ from utils import Preprocess
6
+
7
+
8
+ class Photo2Cartoon:
9
+ def __init__(self):
10
+ self.pre = Preprocess()
11
+ curPath = os.path.abspath(os.path.dirname(__file__))
12
+ # assert os.path.exists('./models/photo2cartoon_weights.onnx'), "[Step1: load weights] Can not find 'photo2cartoon_weights.onnx' in folder 'models!!!'"
13
+ self.session = onnxruntime.InferenceSession(os.path.join(curPath, 'models/photo2cartoon_weights.onnx'))
14
+ print('[Step1: load weights] success!')
15
+
16
+ def inference(self, in_path):
17
+ img = cv2.cvtColor(cv2.imread(in_path), cv2.COLOR_BGR2RGB)
18
+ # face alignment and segmentation
19
+ face_rgba = self.pre.process(img)
20
+ if face_rgba is None:
21
+ print('[Step2: face detect] cannot detect a face!')
22
+ return None
23
+
24
+ print('[Step2: face detect] success!')
25
+ face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
26
+ face = face_rgba[:, :, :3].copy()
27
+ mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
28
+ face = (face * mask + (1 - mask) * 255) / 127.5 - 1
29
+
30
+ face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
31
+
32
+ # inference
33
+ cartoon = self.session.run(['output'], input_feed={'input': face})
34
+
35
+ # post-process
36
+ cartoon = np.transpose(cartoon[0][0], (1, 2, 0))
37
+ cartoon = (cartoon + 1) * 127.5
38
+ cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
39
+ #cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
40
+
41
+ print('[Step3: photo to cartoon] success!')
42
+ return cartoon
43
+
44
+
45
+ if __name__ == '__main__':
46
+ c2p = Photo2Cartoon()
47
+ cartoon = c2p.inference('')
48
+ if cartoon is not None:
49
+ print('Cartoon portrait has been generated successfully!')
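A minimal driver sketch for the ONNX path above, assuming it is run from the p2c directory; inference() returns an RGB uint8 array, or None when no face is detected (the input filename is a placeholder):

import PIL.Image
from test_onnx import Photo2Cartoon

p2c = Photo2Cartoon()
cartoon = p2c.inference('portrait.jpg')   # placeholder path to a portrait photo
if cartoon is not None:
    PIL.Image.fromarray(cartoon).save('cartoon.png')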
p2c/train.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import UgatitSadalinHourglass
2
+ import argparse
3
+ import shutil
4
+ from utils import *
5
+
6
+
7
+ def parse_args():
8
+ """parsing and configuration"""
9
+ desc = "photo2cartoon"
10
+ parser = argparse.ArgumentParser(description=desc)
11
+ parser.add_argument('--phase', type=str, default='train', help='[train / test]')
12
+ parser.add_argument('--light', type=str2bool, default=True, help='[U-GAT-IT full version / U-GAT-IT light version]')
13
+ parser.add_argument('--dataset', type=str, default='photo2cartoon', help='dataset name')
14
+
15
+ parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
16
+ parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
17
+ parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq')
18
+ parser.add_argument('--save_freq', type=int, default=1000, help='The number of model save freq')
19
+ parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
20
+
21
+ parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
22
+ parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN')
23
+ parser.add_argument('--cycle_weight', type=int, default=50, help='Weight for Cycle')
24
+ parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity')
25
+ parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM')
26
+ parser.add_argument('--faceid_weight', type=int, default=1, help='Weight for Face ID')
27
+
28
+ parser.add_argument('--ch', type=int, default=32, help='base channel number per layer')
29
+ parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
30
+
31
+ parser.add_argument('--img_size', type=int, default=256, help='The size of image')
32
+ parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
33
+
34
+ # parser.add_argument('--device', type=str, default='cuda:0', help='Set gpu mode: [cpu, cuda]')
35
+ parser.add_argument('--gpu_ids', type=int, default=[0], nargs='+', help='Set [0, 1, 2, 3] for multi-gpu training')
36
+ parser.add_argument('--benchmark_flag', type=str2bool, default=False)
37
+ parser.add_argument('--resume', type=str2bool, default=False)
38
+ parser.add_argument('--rho_clipper', type=float, default=1.0)
39
+ parser.add_argument('--w_clipper', type=float, default=1.0)
40
+ parser.add_argument('--pretrained_weights', type=str, default='', help='pretrained weight path')
41
+
42
+ args = parser.parse_args()
43
+ args.result_dir = './experiment/{}-size{}-ch{}-{}-lr{}-adv{}-cyc{}-id{}-identity{}-cam{}'.format(
44
+ os.path.basename(__file__)[:-3],
45
+ args.img_size,
46
+ args.ch,
47
+ args.light,
48
+ args.lr,
49
+ args.adv_weight,
50
+ args.cycle_weight,
51
+ args.faceid_weight,
52
+ args.identity_weight,
53
+ args.cam_weight)
54
+
55
+ return check_args(args)
56
+
57
+
58
+ def check_args(args):
59
+ check_folder(os.path.join(args.result_dir, args.dataset, 'model'))
60
+ check_folder(os.path.join(args.result_dir, args.dataset, 'img'))
61
+ check_folder(os.path.join(args.result_dir, args.dataset, 'test'))
62
+ shutil.copy(__file__, args.result_dir)
63
+ return args
64
+
65
+
66
+ def main():
67
+ args = parse_args()
68
+ if args is None:
69
+ exit()
70
+
71
+ gan = UgatitSadalinHourglass(args)
72
+ gan.build_model()
73
+
74
+ if args.phase == 'train':
75
+ gan.train()
76
+ print(" [*] Training finished!")
77
+
78
+ if args.phase == 'test':
79
+ gan.test()
80
+ print(" [*] Test finished!")
81
+
82
+
83
+ if __name__ == '__main__':
84
+ main()
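Sticking to the defaults defined above, a typical training invocation might look like this (the dataset folder layout expected by UgatitSadalinHourglass is not part of this commit):

python train.py --dataset photo2cartoon --batch_size 1 --gpu_ids 0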
p2c/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .preprocess import Preprocess
2
+ from .utils import *
p2c/utils/face_detect.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import face_alignment
5
+
6
+
7
+ class FaceDetect:
8
+ def __init__(self, device, detector):
9
+ # landmarks are detected with the face_alignment library; set device='cuda' to use a GPU.
10
+ self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device, face_detector=detector)
11
+
12
+ def align(self, image):
13
+ landmarks = self.__get_max_face_landmarks(image)
14
+
15
+ if landmarks is None:
16
+ return None
17
+
18
+ else:
19
+ return self.__rotate(image, landmarks)
20
+
21
+ def __get_max_face_landmarks(self, image):
22
+ preds = self.fa.get_landmarks(image)
23
+ if preds is None:
24
+ return None
25
+
26
+ elif len(preds) == 1:
27
+ return preds[0]
28
+
29
+ else:
30
+ # find max face
31
+ areas = []
32
+ for pred in preds:
33
+ landmarks_top = np.min(pred[:, 1])
34
+ landmarks_bottom = np.max(pred[:, 1])
35
+ landmarks_left = np.min(pred[:, 0])
36
+ landmarks_right = np.max(pred[:, 0])
37
+ areas.append((landmarks_bottom - landmarks_top) * (landmarks_right - landmarks_left))
38
+ max_face_index = np.argmax(areas)
39
+ return preds[max_face_index]
40
+
41
+ @staticmethod
42
+ def __rotate(image, landmarks):
43
+ # rotation angle
44
+ left_eye_corner = landmarks[36]
45
+ right_eye_corner = landmarks[45]
46
+ radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) / (left_eye_corner[0] - right_eye_corner[0]))
47
+
48
+ # image size after rotating
49
+ height, width, _ = image.shape
50
+ cos = math.cos(radian)
51
+ sin = math.sin(radian)
52
+ new_w = int(width * abs(cos) + height * abs(sin))
53
+ new_h = int(width * abs(sin) + height * abs(cos))
54
+
55
+ # translation
56
+ Tx = new_w // 2 - width // 2
57
+ Ty = new_h // 2 - height // 2
58
+
59
+ # affine matrix
60
+ M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
61
+ [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
62
+
63
+ image_rotate = cv2.warpAffine(image, M, (new_w, new_h), borderValue=(255, 255, 255))
64
+
65
+ landmarks = np.concatenate([landmarks, np.ones((landmarks.shape[0], 1))], axis=1)
66
+ landmarks_rotate = np.dot(M, landmarks.T).T
67
+ return image_rotate, landmarks_rotate
68
+
69
+
70
+ if __name__ == '__main__':
71
+ img = cv2.cvtColor(cv2.imread('3989161_1.jpg'), cv2.COLOR_BGR2RGB)
72
+ fd = FaceDetect(device='cpu', detector='dlib')  # __init__ also requires a detector: 'dlib' or 'sfd'
73
+ face_info = fd.align(img)
74
+ if face_info is not None:
75
+ image_align, landmarks_align = face_info
76
+
77
+ for i in range(landmarks_align.shape[0]):
78
+ cv2.circle(image_align, (int(landmarks_align[i][0]), int(landmarks_align[i][1])), 2, (255, 0, 0), -1)
79
+
80
+ cv2.imwrite('image_align.png', cv2.cvtColor(image_align, cv2.COLOR_RGB2BGR))
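A quick, illustrative check of the affine matrix built in __rotate: it should map the original image center to the center of the enlarged canvas (all numbers below are invented):

import numpy as np

width, height, radian = 400, 300, 0.1
cos, sin = np.cos(radian), np.sin(radian)
new_w = int(width * abs(cos) + height * abs(sin))
new_h = int(width * abs(sin) + height * abs(cos))
Tx, Ty = new_w // 2 - width // 2, new_h // 2 - height // 2
M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
              [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
print(M @ np.array([width / 2., height / 2., 1.]))   # approximately [new_w / 2, new_h / 2]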
p2c/utils/face_seg.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.python.platform import gfile
6
+
7
+
8
+ curPath = os.path.abspath(os.path.dirname(__file__))
9
+
10
+
11
+ class FaceSeg:
12
+ def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
13
+ config = tf.compat.v1.ConfigProto()
14
+ config.gpu_options.allow_growth = True
15
+ self._graph = tf.Graph()
16
+ self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
17
+
18
+ self.pb_file_path = model_path
19
+ self._restore_from_pb()
20
+ self.input_op = self._sess.graph.get_tensor_by_name('input_1:0')
21
+ self.output_op = self._sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')
22
+
23
+ def _restore_from_pb(self):
24
+ with self._sess.as_default():
25
+ with self._graph.as_default():
26
+ with gfile.FastGFile(self.pb_file_path, 'rb') as f:
27
+ graph_def = tf.compat.v1.GraphDef()
28
+ graph_def.ParseFromString(f.read())
29
+ tf.import_graph_def(graph_def, name='')
30
+
31
+ def input_transform(self, image):
32
+ image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
33
+ image_input = (image / 255.)[np.newaxis, :, :, :]
34
+ return image_input
35
+
36
+ def output_transform(self, output, shape):
37
+ output = cv2.resize(output, (shape[1], shape[0]))
38
+ image_output = (output * 255).astype(np.uint8)
39
+ return image_output
40
+
41
+ def get_mask(self, image):
42
+ image_input = self.input_transform(image)
43
+ output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
44
+ return self.output_transform(output, shape=image.shape[:2])
p2c/utils/preprocess.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .face_detect import FaceDetect
2
+ from .face_seg import FaceSeg
3
+ import numpy as np
4
+
5
+
6
+ class Preprocess:
7
+ def __init__(self, device='cpu', detector='dlib'):
8
+ self.detect = FaceDetect(device, detector) # device = 'cpu' or 'cuda', detector = 'dlib' or 'sfd'
9
+ self.segment = FaceSeg()
10
+
11
+ def process(self, image):
12
+ face_info = self.detect.align(image)
13
+ if face_info is None:
14
+ return None
15
+ image_align, landmarks_align = face_info
16
+
17
+ face = self.__crop(image_align, landmarks_align)
18
+ mask = self.segment.get_mask(face)
19
+ return np.dstack((face, mask))
20
+
21
+ @staticmethod
22
+ def __crop(image, landmarks):
23
+ landmarks_top = np.min(landmarks[:, 1])
24
+ landmarks_bottom = np.max(landmarks[:, 1])
25
+ landmarks_left = np.min(landmarks[:, 0])
26
+ landmarks_right = np.max(landmarks[:, 0])
27
+
28
+ # expand bbox
29
+ top = int(landmarks_top - 0.8 * (landmarks_bottom - landmarks_top))
30
+ bottom = int(landmarks_bottom + 0.3 * (landmarks_bottom - landmarks_top))
31
+ left = int(landmarks_left - 0.3 * (landmarks_right - landmarks_left))
32
+ right = int(landmarks_right + 0.3 * (landmarks_right - landmarks_left))
33
+
34
+ if bottom - top > right - left:
35
+ left -= ((bottom - top) - (right - left)) // 2
36
+ right = left + (bottom - top)
37
+ else:
38
+ top -= ((right - left) - (bottom - top)) // 2
39
+ bottom = top + (right - left)
40
+
41
+ image_crop = np.ones((bottom - top + 1, right - left + 1, 3), np.uint8) * 255
42
+
43
+ h, w = image.shape[:2]
44
+ left_white = max(0, -left)
45
+ left = max(0, left)
46
+ right = min(right, w-1)
47
+ right_white = left_white + (right-left)
48
+ top_white = max(0, -top)
49
+ top = max(0, top)
50
+ bottom = min(bottom, h-1)
51
+ bottom_white = top_white + (bottom - top)
52
+
53
+ image_crop[top_white:bottom_white+1, left_white:right_white+1] = image[top:bottom+1, left:right+1].copy()
54
+ return image_crop
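To make the bounding-box expansion in __crop concrete, here is a worked example with invented landmark extents (top=100, bottom=300, left=120, right=280):

top = int(100 - 0.8 * (300 - 100))      # -60: extra headroom above the face
bottom = int(300 + 0.3 * (300 - 100))   # 360
left = int(120 - 0.3 * (280 - 120))     # 72
right = int(280 + 0.3 * (280 - 120))    # 328
# height (420) exceeds width (256), so the box is widened into a square:
left -= ((bottom - top) - (right - left)) // 2   # 72 - 82 = -10
right = left + (bottom - top)                    # 410
# negative coordinates fall outside the image and are padded with white.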
p2c/utils/seg_model_384.pb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66a04bc2032b54013d2ae994b34d22518144276f1cbdd2d8cbb1a4a28f50285f
3
+ size 32477258
p2c/utils/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from scipy import misc
6
+
7
+
8
+ def load_test_data(image_path, size=256):
9
+ img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
10
+ if img is None:
11
+ return None
12
+
13
+ h, w, c = img.shape
14
+ if img.shape[2] == 4:
15
+ white = np.ones((h, w, 3), np.uint8) * 255
16
+ img_rgb = img[:, :, :3].copy()
17
+ mask = img[:, :, 3].copy()
18
+ mask = (mask / 255).astype(np.uint8)
19
+ img = (img_rgb * mask[:, :, np.newaxis]).astype(np.uint8) + white * (1 - mask[:, :, np.newaxis])
20
+
21
+ img = cv2.resize(img, (size, size), cv2.INTER_AREA)
22
+ img = RGB2BGR(img)
23
+
24
+ img = np.expand_dims(img, axis=0)
25
+ img = preprocessing(img)
26
+ return img
27
+
28
+
29
+ def preprocessing(x):
30
+ x = x/127.5 - 1
31
+ # -1 ~ 1
32
+ return x
33
+
34
+
35
+ def save_images(images, size, image_path):
36
+ return imsave(inverse_transform(images), size, image_path)
37
+
38
+
39
+ def inverse_transform(images):
40
+ return (images+1.) / 2
41
+
42
+
43
+ def imsave(images, size, path):
44
+ return cv2.imwrite(path, (merge(images, size) * 255).astype(np.uint8))  # scipy.misc.imsave was removed from SciPy; save via OpenCV
45
+
46
+
47
+ def merge(images, size):
48
+ h, w = images.shape[1], images.shape[2]
49
+ img = np.zeros((h * size[0], w * size[1], 3))
50
+ for idx, image in enumerate(images):
51
+ i = idx % size[1]
52
+ j = idx // size[1]
53
+ img[h*j:h*(j+1), w*i:w*(i+1), :] = image
54
+
55
+ return img
56
+
57
+
58
+ def check_folder(log_dir):
59
+ if not os.path.exists(log_dir):
60
+ os.makedirs(log_dir)
61
+ return log_dir
62
+
63
+
64
+ def str2bool(x):
65
+ return x.lower() == 'true'  # only the string 'true' (case-insensitive) maps to True
66
+
67
+
68
+ def cam(x, size=256):
69
+ x = x - np.min(x)
70
+ cam_img = x / np.max(x)
71
+ cam_img = np.uint8(255 * cam_img)
72
+ cam_img = cv2.resize(cam_img, (size, size))
73
+ cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
74
+ return cam_img / 255.0
75
+
76
+
77
+ def imagenet_norm(x):
78
+ mean = [0.485, 0.456, 0.406]
79
+ std = [0.229, 0.224, 0.225]  # standard ImageNet std
80
+ mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
81
+ std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
82
+ return (x - mean) / std
83
+
84
+
85
+ def denorm(x):
86
+ return x * 0.5 + 0.5
87
+
88
+
89
+ def tensor2numpy(x):
90
+ return x.detach().cpu().numpy().transpose(1, 2, 0)
91
+
92
+
93
+ def RGB2BGR(x):
94
+ return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
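For reference, preprocessing() and inverse_transform() in utils.py above are simple affine maps; an illustrative round trip:

import numpy as np

x = np.array([0., 127.5, 255.])
y = x / 127.5 - 1       # preprocessing(): maps [0, 255] to [-1, 1]
print((y + 1.) / 2)     # inverse_transform(): back to [0, 1] -> [0.  0.5 1. ]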
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cmake
2
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python-headless==4.5.5.62
2
+ Pillow==9.0.1
3
+ scipy==1.7.3
4
+ tensorflow-gpu==1.14.0
5
+ scikit-image==0.14.5
6
+ onnxruntime
7
+ face-alignment
8
+ dlib
9
+