init
- README.md +1 -0
- app.py +100 -0
- p2c/cog.yaml +27 -0
- p2c/data_process.py +30 -0
- p2c/dataset.py +108 -0
- p2c/dataset/README.md +20 -0
- p2c/images/QRcode.jpg +0 -0
- p2c/images/data_process.jpg +0 -0
- p2c/images/photo_test.jpg +0 -0
- p2c/images/results.png +0 -0
- p2c/images/title.png +0 -0
- p2c/models/UGATIT_sadalin_hourglass.py +489 -0
- p2c/models/__init__.py +3 -0
- p2c/models/face_features.py +31 -0
- p2c/models/mobilefacenet.py +258 -0
- p2c/models/model_mobilefacenet.pth +3 -0
- p2c/models/networks.py +485 -0
- p2c/models/photo2cartoon_weights.onnx +3 -0
- p2c/models/photo2cartoon_weights.pt +3 -0
- p2c/predict.py +57 -0
- p2c/test.py +63 -0
- p2c/test_onnx.py +49 -0
- p2c/train.py +84 -0
- p2c/utils/__init__.py +2 -0
- p2c/utils/face_detect.py +80 -0
- p2c/utils/face_seg.py +44 -0
- p2c/utils/preprocess.py +54 -0
- p2c/utils/seg_model_384.pb +3 -0
- p2c/utils/utils.py +94 -0
- packages.txt +2 -0
- requirements.txt +9 -0
README.md
CHANGED
@@ -1,4 +1,5 @@
```diff
 ---
+python_version: 3.7
 title: Photo2cartoon
 emoji: π
 colorFrom: gray
```
app.py
ADDED
@@ -0,0 +1,100 @@
```python
#!/usr/bin/env python

from __future__ import annotations
import argparse
import functools
import os
import pathlib
import sys
from typing import Callable


import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image

import cv2

from io import BytesIO
sys.path.insert(0, 'p2c')

from test_onnx import Photo2Cartoon


ORIGINAL_REPO_URL = 'https://github.com/minivision-ai/photo2cartoon'
TITLE = 'minivision-ai/photo2cartoon'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.

"""
ARTICLE = """

"""


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def run(
    image,
    p2c: Photo2Cartoon,
) -> PIL.Image.Image:
    cartoon = p2c.inference(image.name)
    return PIL.Image.fromarray(cartoon)


def main():
    gr.close_all()

    args = parse_args()

    p2c = Photo2Cartoon()

    # bind the model by keyword so Gradio passes only the image positionally
    func = functools.partial(run, p2c=p2c)
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
        ],
        [
            gr.outputs.Image(type='pil', label='Result'),
        ],
        #examples=examples,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
```
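For a quick sanity check outside the Gradio UI, the same pipeline can be driven directly. This is a minimal sketch, assuming the Space root as the working directory, the requirements installed, and that `inference()` accepts an image path exactly as `run()` above passes `image.name`; the bundled test image path comes from the file list in this commit.

```python
# Minimal sketch: call the Photo2Cartoon pipeline once without Gradio.
# Assumes the repo root as the working directory and that inference()
# accepts a file path, mirroring run() above.
import sys

import PIL.Image

sys.path.insert(0, 'p2c')
from test_onnx import Photo2Cartoon

p2c = Photo2Cartoon()
cartoon = p2c.inference('p2c/images/photo_test.jpg')
PIL.Image.fromarray(cartoon).save('cartoon.png')
```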
p2c/cog.yaml
ADDED
@@ -0,0 +1,27 @@
```yaml
predict: "predict.py:Predictor"
build:
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "cmake==3.21.1"
    - "torch==1.8.0"
    - "torchvision==0.9.0"
    - "numpy==1.19.2"
    - "ipython==7.21.0"
    - "opencv-python==4.3.0.38"
    - "face-alignment==1.3.4"
    - "tensorflow-gpu==2.5.0"
  pre_install:
    - pip install dlib
```
p2c/data_process.py
ADDED
@@ -0,0 +1,30 @@
```python
import os
import cv2
import numpy as np
from tqdm import tqdm
import argparse

from utils import Preprocess


parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, help='photo folder path')
parser.add_argument('--save_path', type=str, help='save folder path')

args = parser.parse_args()
os.makedirs(args.save_path, exist_ok=True)

pre = Preprocess()

for idx, img_name in enumerate(tqdm(os.listdir(args.data_path))):
    img = cv2.cvtColor(cv2.imread(os.path.join(args.data_path, img_name)), cv2.COLOR_BGR2RGB)

    # face alignment and segmentation
    face_rgba = pre.process(img)
    if face_rgba is not None:
        # change background to white
        face = face_rgba[:, :, :3].copy()
        mask = face_rgba[:, :, 3].copy()[:, :, np.newaxis] / 255.
        face_white_bg = (face * mask + (1 - mask) * 255).astype(np.uint8)

        cv2.imwrite(os.path.join(args.save_path, str(idx).zfill(4) + '.png'),
                    cv2.cvtColor(face_white_bg, cv2.COLOR_RGB2BGR))
```
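The only non-obvious step above is the alpha compositing of the segmented face onto a white background. A self-contained sketch of that arithmetic, using synthetic pixel values rather than data from this repo:

```python
# Self-contained sketch of the white-background compositing used above.
import numpy as np

face_rgba = np.zeros((2, 2, 4), dtype=np.uint8)
face_rgba[..., :3] = 100   # gray "face" pixels
face_rgba[0, 0, 3] = 255   # opaque pixel: keeps the face colour
face_rgba[1, 1, 3] = 0     # transparent pixel: becomes white background

face = face_rgba[:, :, :3].astype(np.float64)
mask = face_rgba[:, :, 3:4] / 255.0
face_white_bg = (face * mask + (1 - mask) * 255).astype(np.uint8)

print(face_white_bg[0, 0])  # [100 100 100] -> kept
print(face_white_bg[1, 1])  # [255 255 255] -> replaced by white
```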
p2c/dataset.py
ADDED
@@ -0,0 +1,108 @@
```python
import torch.utils.data as data

from PIL import Image

import os
import os.path


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, extensions):
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if has_file_allowed_extension(fname, extensions):
                path = os.path.join(root, fname)
                item = (path, 0)
                images.append(item)

    return images


class DatasetFolder(data.Dataset):
    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        # classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, extensions)
        if len(samples) == 0:
            raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(extensions)))

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = samples

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def default_loader(path):
    return pil_loader(path)


class ImageFolder(DatasetFolder):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                          transform=transform,
                                          target_transform=target_transform)
        self.imgs = self.samples
```
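A hedged usage sketch of this flat `ImageFolder`: the folder path and image size below are placeholders, and the wiring mirrors what `build_model()` in `UGATIT_sadalin_hourglass.py` does further down in this commit.

```python
# Illustrative only: load one batch with the flat ImageFolder defined above.
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import ImageFolder

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
trainA = ImageFolder('dataset/photo2cartoon/trainA', transform)
loader = DataLoader(trainA, batch_size=1, shuffle=True)

images, targets = next(iter(loader))
print(images.shape)   # torch.Size([1, 3, 256, 256])
print(targets)        # tensor([0]); make_dataset() assigns class index 0 to every file
```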
p2c/dataset/README.md
ADDED
@@ -0,0 +1,20 @@
```
├── dataset
    └── photo2cartoon
        ├── trainA
            ├── xxx.jpg
            ├── yyy.png
            └── ...
        ├── trainB
            ├── zzz.jpg
            ├── www.png
            └── ...
        ├── testA
            ├── aaa.jpg
            ├── bbb.png
            └── ...
        └── testB
            ├── ccc.jpg
            ├── ddd.png
            └── ...
```
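The trainer composes these folders as `os.path.join('dataset', <dataset_name>, 'trainA')` and so on. A minimal sketch that checks the expected layout is in place; the dataset name `photo2cartoon` is taken from the tree above, not from the training scripts.

```python
# Minimal check that the dataset tree above exists and is populated.
import os

root = os.path.join('dataset', 'photo2cartoon')
for split in ('trainA', 'trainB', 'testA', 'testB'):
    folder = os.path.join(root, split)
    count = len(os.listdir(folder)) if os.path.isdir(folder) else 0
    print(f'{folder}: {count} files')
```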
p2c/images/QRcode.jpg
ADDED
p2c/images/data_process.jpg
ADDED
p2c/images/photo_test.jpg
ADDED
p2c/images/results.png
ADDED
p2c/images/title.png
ADDED
p2c/models/UGATIT_sadalin_hourglass.py
ADDED
@@ -0,0 +1,489 @@
```python
import time
import itertools
from dataset import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from .networks import *
from utils import *
from glob import glob
from .face_features import FaceFeatures


class UgatitSadalinHourglass(object):
    def __init__(self, args):
        self.light = args.light

        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.result_dir = args.result_dir
        self.dataset = args.dataset

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.ch = args.ch

        """ Weight """
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight
        self.faceid_weight = args.faceid_weight

        """ Discriminator """
        self.n_dis = args.n_dis

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = f'cuda:{args.gpu_ids[0]}'
        self.gpu_ids = args.gpu_ids
        self.benchmark_flag = args.benchmark_flag
        self.resume = args.resume
        self.rho_clipper = args.rho_clipper
        self.w_clipper = args.w_clipper
        self.pretrained_weights = args.pretrained_weights

        if torch.backends.cudnn.enabled and self.benchmark_flag:
            print('set benchmark !')
            torch.backends.cudnn.benchmark = True

        print("##### Information #####")
        print("# light : ", self.light)
        print("# dataset : ", self.dataset)
        print("# batch_size : ", self.batch_size)
        print("# iteration per epoch : ", self.iteration)

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# faceid_weight : ", self.faceid_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)
        print("# rho_clipper: ", self.rho_clipper)
        print("# w_clipper: ", self.w_clipper)

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """ DataLoader """
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size + 30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
        self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
        self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
        self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)

        self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
        self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)

        """ Define Generator, Discriminator """
        self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
        self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

        self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device)

        """ Define Loss """
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        """ Trainer """
        self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001)
        self.D_optim = torch.optim.Adam(
            itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()),
            lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001
        )

        """ Define Rho clipper to constraint the value of rho in AdaLIN and LIN"""
        self.Rho_clipper = RhoClipper(0, self.rho_clipper)
        self.W_Clipper = WClipper(0, self.w_clipper)

    def train(self):
        self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()

        start_iter = 1
        if self.resume:
            model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
                print(" [*] Load SUCCESS")
                if self.decay_flag and start_iter > (self.iteration // 2):
                    self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
                    self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)

        if self.pretrained_weights:
            params = torch.load(self.pretrained_weights, map_location=self.device)
            self.genA2B.load_state_dict(params['genA2B'])
            self.genB2A.load_state_dict(params['genB2A'])
            self.disGA.load_state_dict(params['disGA'])
            self.disGB.load_state_dict(params['disGB'])
            self.disLA.load_state_dict(params['disLA'])
            self.disLB.load_state_dict(params['disLB'])
            print(" [*] Load {} Success".format(self.pretrained_weights))

        if len(self.gpu_ids) > 1:
            self.genA2B = nn.DataParallel(self.genA2B, device_ids=self.gpu_ids)
            self.genB2A = nn.DataParallel(self.genB2A, device_ids=self.gpu_ids)
            self.disGA = nn.DataParallel(self.disGA, device_ids=self.gpu_ids)
            self.disGB = nn.DataParallel(self.disGB, device_ids=self.gpu_ids)
            self.disLA = nn.DataParallel(self.disLA, device_ids=self.gpu_ids)
            self.disLB = nn.DataParallel(self.disLB, device_ids=self.gpu_ids)

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            if self.decay_flag and step > (self.iteration // 2):
                self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
                self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))

            try:
                real_A, _ = trainA_iter.next()
            except:
                trainA_iter = iter(self.trainA_loader)
                real_A, _ = trainA_iter.next()

            try:
                real_B, _ = trainB_iter.next()
            except:
                trainB_iter = iter(self.trainB_loader)
                real_B, _ = trainB_iter.next()

            real_A, real_B = real_A.to(self.device), real_B.to(self.device)

            # Update D
            self.D_optim.zero_grad()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + \
                           self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))

            D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + \
                               self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))

            D_ad_loss_LA = self.MSE_loss(real_LA_logit, torch.ones_like(real_LA_logit).to(self.device)) + \
                           self.MSE_loss(fake_LA_logit, torch.zeros_like(fake_LA_logit).to(self.device))

            D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, torch.ones_like(real_LA_cam_logit).to(self.device)) + \
                               self.MSE_loss(fake_LA_cam_logit, torch.zeros_like(fake_LA_cam_logit).to(self.device))

            D_ad_loss_GB = self.MSE_loss(real_GB_logit, torch.ones_like(real_GB_logit).to(self.device)) + \
                           self.MSE_loss(fake_GB_logit, torch.zeros_like(fake_GB_logit).to(self.device))

            D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, torch.ones_like(real_GB_cam_logit).to(self.device)) + \
                               self.MSE_loss(fake_GB_cam_logit, torch.zeros_like(fake_GB_cam_logit).to(self.device))

            D_ad_loss_LB = self.MSE_loss(real_LB_logit, torch.ones_like(real_LB_logit).to(self.device)) + \
                           self.MSE_loss(fake_LB_logit, torch.zeros_like(fake_LB_logit).to(self.device))

            D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, torch.ones_like(real_LB_cam_logit).to(self.device)) + \
                               self.MSE_loss(fake_LB_cam_logit, torch.zeros_like(fake_LB_cam_logit).to(self.device))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit, torch.ones_like(fake_LA_logit).to(self.device))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, torch.ones_like(fake_LA_cam_logit).to(self.device))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit, torch.ones_like(fake_GB_logit).to(self.device))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, torch.ones_like(fake_GB_cam_logit).to(self.device))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit, torch.ones_like(fake_LB_logit).to(self.device))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, torch.ones_like(fake_LB_cam_logit).to(self.device))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_id_loss_A = self.facenet.cosine_distance(real_A, fake_A2B)
            G_id_loss_B = self.facenet.cosine_distance(real_B, fake_B2A)
            if len(self.gpu_ids) > 1:
                G_id_loss_A = torch.mean(G_id_loss_A)
                G_id_loss_B = torch.mean(G_id_loss_B)

            G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) + \
                           self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
            G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) + \
                           self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))

            G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + \
                       self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + \
                       self.cam_weight * G_cam_loss_A + self.faceid_weight * G_id_loss_A
            G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + \
                       self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + \
                       self.cam_weight * G_cam_loss_B + self.faceid_weight * G_id_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.step()

            # clip parameter of Soft-AdaLIN and LIN, applied after optimizer step
            self.genA2B.apply(self.Rho_clipper)
            self.genB2A.apply(self.Rho_clipper)

            self.genA2B.apply(self.W_Clipper)
            self.genB2A.apply(self.W_Clipper)

            if step % 10 == 0:
                print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (step, self.iteration, time.time() - start_time, Discriminator_loss, Generator_loss))
            if step % self.print_freq == 0:
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                with torch.no_grad():
                    for _ in range(train_sample_num):
                        try:
                            real_A, _ = trainA_iter.next()
                        except:
                            trainA_iter = iter(self.trainA_loader)
                            real_A, _ = trainA_iter.next()

                        try:
                            real_B, _ = trainB_iter.next()
                        except:
                            trainB_iter = iter(self.trainB_loader)
                            real_B, _ = trainB_iter.next()
                        real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                                                   cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                                                   cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                                                   cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                        B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                                                   cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                                                   cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                                                   cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                    for _ in range(test_sample_num):
                        try:
                            real_A, _ = testA_iter.next()
                        except:
                            testA_iter = iter(self.testA_loader)
                            real_A, _ = testA_iter.next()

                        try:
                            real_B, _ = testB_iter.next()
                        except:
                            testB_iter = iter(self.testB_loader)
                            real_B, _ = testB_iter.next()
                        real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                        fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                        fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                        fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                        fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                        fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                        fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                        A2B = np.concatenate((A2B, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                                                   cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                                                   cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                                                   cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                        B2A = np.concatenate((B2A, np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                                                   cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                                                   cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                                                   cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                                                   RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'A2B_%07d.png' % step), A2B * 255.0)
                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'B2A_%07d.png' % step), B2A * 255.0)
                self.genA2B.train(), self.genB2A.train(), self.disGA.train(), self.disGB.train(), self.disLA.train(), self.disLB.train()

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

            if step % 1000 == 0:
                params = {}

                if len(self.gpu_ids) > 1:
                    params['genA2B'] = self.genA2B.module.state_dict()
                    params['genB2A'] = self.genB2A.module.state_dict()
                    params['disGA'] = self.disGA.module.state_dict()
                    params['disGB'] = self.disGB.module.state_dict()
                    params['disLA'] = self.disLA.module.state_dict()
                    params['disLB'] = self.disLB.module.state_dict()

                else:
                    params['genA2B'] = self.genA2B.state_dict()
                    params['genB2A'] = self.genB2A.state_dict()
                    params['disGA'] = self.disGA.state_dict()
                    params['disGB'] = self.disGB.state_dict()
                    params['disLA'] = self.disLA.state_dict()
                    params['disLB'] = self.disLB.state_dict()
                torch.save(params, os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))

    def save(self, dir, step):
        params = {}

        if len(self.gpu_ids) > 1:
            params['genA2B'] = self.genA2B.module.state_dict()
            params['genB2A'] = self.genB2A.module.state_dict()
            params['disGA'] = self.disGA.module.state_dict()
            params['disGB'] = self.disGB.module.state_dict()
            params['disLA'] = self.disLA.module.state_dict()
            params['disLB'] = self.disLB.module.state_dict()

        else:
            params['genA2B'] = self.genA2B.state_dict()
            params['genB2A'] = self.genB2A.state_dict()
            params['disGA'] = self.disGA.state_dict()
            params['disGB'] = self.disGB.state_dict()
            params['disLA'] = self.disLA.state_dict()
            params['disLB'] = self.disLB.state_dict()
        torch.save(params, os.path.join(dir, self.dataset + '_params_%07d.pt' % step))

    def load(self, dir, step):
        params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
        self.genA2B.load_state_dict(params['genA2B'])
        self.genB2A.load_state_dict(params['genB2A'])
        self.disGA.load_state_dict(params['disGA'])
        self.disGB.load_state_dict(params['disGB'])
        self.disLA.load_state_dict(params['disLA'])
        self.disLB.load_state_dict(params['disLB'])

    def test(self):
        model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        if not len(model_list) == 0:
            model_list.sort()
            iter = int(model_list[-1].split('_')[-1].split('.')[0])
            self.load(os.path.join(self.result_dir, self.dataset, 'model'), iter)
            print(" [*] Load SUCCESS")
        else:
            print(" [*] Load FAILURE")
            return

        self.genA2B.eval(), self.genB2A.eval()
        with torch.no_grad():
            for n, (real_A, _) in enumerate(self.testA_loader):
                real_A = real_A.to(self.device)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)

                A2B = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                                      cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                                      cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                                      cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)

                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'A2B_%d.png' % (n + 1)), A2B * 255.0)

            for n, (real_B, _) in enumerate(self.testB_loader):
                real_B = real_B.to(self.device)

                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                B2A = np.concatenate((RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                                      cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                                      cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                                      cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                                      RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)

                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test', 'B2A_%d.png' % (n + 1)), B2A * 255.0)
```
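The adversarial terms in `train()` above follow the least-squares GAN pattern: real logits are regressed toward 1 and fake logits toward 0 for the discriminator, then the same fake logits toward 1 for the generator. A tiny self-contained sketch with random stand-in logits:

```python
# Sketch of the least-squares adversarial terms used in train() above,
# on random stand-in logits rather than real discriminator outputs.
import torch
import torch.nn as nn

mse = nn.MSELoss()
real_logit = torch.randn(1, 1, 30, 30)   # discriminator output on a real image
fake_logit = torch.randn(1, 1, 30, 30)   # discriminator output on a generated image

# Discriminator objective: push real logits toward 1 and fake logits toward 0.
d_loss = mse(real_logit, torch.ones_like(real_logit)) + \
         mse(fake_logit, torch.zeros_like(fake_logit))

# Generator objective: push the same fake logits toward 1.
g_loss = mse(fake_logit, torch.ones_like(fake_logit))
print(d_loss.item(), g_loss.item())
```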
p2c/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
```python
from .networks import ResnetGenerator
from .UGATIT_sadalin_hourglass import UgatitSadalinHourglass
```
p2c/models/face_features.py
ADDED
@@ -0,0 +1,31 @@
```python
import torch
import torch.nn.functional as F
from .mobilefacenet import MobileFaceNet


class FaceFeatures(object):
    def __init__(self, weights_path, device):
        self.device = device
        self.model = MobileFaceNet(512).to(device)
        self.model.load_state_dict(torch.load(weights_path))
        self.model.eval()

    def infer(self, batch_tensor):
        # crop face
        h, w = batch_tensor.shape[2:]
        top = int(h / 2.1 * (0.8 - 0.33))
        bottom = int(h - (h / 2.1 * 0.3))
        size = bottom - top
        left = int(w / 2 - size / 2)
        right = left + size
        batch_tensor = batch_tensor[:, :, top: bottom, left: right]

        batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)

        features = self.model(batch_tensor)
        return features

    def cosine_distance(self, batch_tensor1, batch_tensor2):
        feature1 = self.infer(batch_tensor1)
        feature2 = self.infer(batch_tensor2)
        return 1 - torch.cosine_similarity(feature1, feature2)
```
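The identity-preservation term used by the trainer reduces to 1 minus the cosine similarity of two face embeddings. A self-contained sketch of that formula on random embeddings (no MobileFaceNet weights needed):

```python
# Sketch of the identity term: 1 - cosine similarity between two embeddings,
# the same formula as cosine_distance() above.
import torch
import torch.nn.functional as F

emb_real = F.normalize(torch.randn(2, 512), dim=1)   # stand-in embeddings of real faces
emb_fake = F.normalize(torch.randn(2, 512), dim=1)   # stand-in embeddings of generated faces

dist = 1 - torch.cosine_similarity(emb_real, emb_fake)
print(dist)   # one value per image; 0 means identical identity features, 2 means opposite
```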
p2c/models/mobilefacenet.py
ADDED
@@ -0,0 +1,258 @@
```python
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout, \
    MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
import torch
from collections import namedtuple


################################## Original Arcface Model #############################################################

class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    return blocks


class Backbone(Module):
    def __init__(self, num_layers, drop_ratio, mode='ir'):
        super(Backbone, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * 7 * 7, 512),
                                       BatchNorm1d(512))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel,
                                bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


################################## MobileFaceNet #############################################################

class Conv_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
                           bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class Linear_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
                           bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Depth_Wise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(
                Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)

        out = self.conv2_dw(out)

        out = self.conv_23(out)

        out = self.conv_3(out)

        out = self.conv_34(out)

        out = self.conv_4(out)

        out = self.conv_45(out)

        out = self.conv_5(out)

        out = self.conv_6_sep(out)

        out = self.conv_6_dw(out)

        out = self.conv_6_flatten(out)

        out = self.linear(out)

        out = self.bn(out)
        return l2_norm(out)
```
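As a shape sanity check, the network expects 112x112 RGB crops (what `FaceFeatures.infer()` resizes to) and emits an L2-normalized embedding. A quick sketch with randomly initialized weights, run from the Space root (the `sys.path` tweak is only to import the module without the package `__init__`):

```python
# Shape sanity check for MobileFaceNet with random weights (no checkpoint needed).
import sys

import torch

sys.path.insert(0, 'p2c/models')       # import the module directly, bypassing the package __init__
from mobilefacenet import MobileFaceNet

net = MobileFaceNet(512).eval()
with torch.no_grad():
    emb = net(torch.randn(4, 3, 112, 112))   # FaceFeatures.infer() resizes face crops to 112x112
print(emb.shape)          # torch.Size([4, 512])
print(emb.norm(dim=1))    # ~1.0 per row, because of l2_norm()
```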
p2c/models/model_mobilefacenet.pth
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:4f3bbd745247b32641724bf6d7964df7fd94ea5a098fe16d692b412fe44cd59b
size 4938364
```
p2c/models/networks.py
ADDED
@@ -0,0 +1,485 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class ResnetGenerator(nn.Module):
    def __init__(self, ngf=64, img_size=256, light=False):
        super(ResnetGenerator, self).__init__()
        self.light = light

        self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf),
                                        nn.ReLU(True))

        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-Sampling
        self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf * 2),
                                        nn.ReLU(True))

        self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*4),
                                        nn.ReLU(True))

        # Encoder Bottleneck
        self.EncodeBlock1 = ResnetBlock(ngf*4)
        self.EncodeBlock2 = ResnetBlock(ngf*4)
        self.EncodeBlock3 = ResnetBlock(ngf*4)
        self.EncodeBlock4 = ResnetBlock(ngf*4)

        # Class Activation Map
        self.gap_fc = nn.Linear(ngf*4, 1)
        self.gmp_fc = nn.Linear(ngf*4, 1)
        self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))
        else:
            self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))

        # Decoder Bottleneck
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)

        # Up-Sampling
        self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf*2),
                                      nn.ReLU(True))

        self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))

        self.HourGlass3 = HourGlass(ngf, ngf)
        self.HourGlass4 = HourGlass(ngf, ngf, False)

        self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.Tanh())

    def forward(self, x):
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)

        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        x = self.EncodeBlock1(x)
        content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock2(x)
        content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock3(x)
        content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock4(x)
        content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)

        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.view(x_.shape[0], -1))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        x = self.DecodeBlock1(x, content_features4, style_features)
        x = self.DecodeBlock2(x, content_features3, style_features)
        x = self.DecodeBlock3(x, content_features2, style_features)
        x = self.DecodeBlock4(x, content_features1, style_features)

        x = self.UpBlock1(x)
        x = self.UpBlock2(x)

        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)

        return out, cam_logit, heatmap


class ConvBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ConvBlock, self).__init__()
        self.dim_out = dim_out

        self.ConvBlock1 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_in, dim_out//2, kernel_size=3, stride=1, bias=False))

        self.ConvBlock2 = nn.Sequential(nn.InstanceNorm2d(dim_out//2),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//2, dim_out//4, kernel_size=3, stride=1, bias=False))

        self.ConvBlock3 = nn.Sequential(nn.InstanceNorm2d(dim_out//4),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//4, dim_out//4, kernel_size=3, stride=1, bias=False))

        self.ConvBlock4 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False))

    def forward(self, x):
        residual = x

        x1 = self.ConvBlock1(x)
        x2 = self.ConvBlock2(x1)
        x3 = self.ConvBlock3(x2)
        out = torch.cat((x1, x2, x3), 1)

        if residual.size(1) != self.dim_out:
            residual = self.ConvBlock4(residual)

        return residual + out


class HourGlass(nn.Module):
    def __init__(self, dim_in, dim_out, use_res=True):
        super(HourGlass, self).__init__()
        self.use_res = use_res

        self.HG = nn.Sequential(HourGlassBlock(dim_in, dim_out),
                                ConvBlock(dim_out, dim_out),
                                nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1, bias=False),
                                nn.InstanceNorm2d(dim_out),
                                nn.ReLU(True))

        self.Conv1 = nn.Conv2d(dim_out, 3, kernel_size=1, stride=1)

        if self.use_res:
            self.Conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1)
            self.Conv3 = nn.Conv2d(3, dim_out, kernel_size=1, stride=1)

    def forward(self, x):
        ll = self.HG(x)
        tmp_out = self.Conv1(ll)

        if self.use_res:
            ll = self.Conv2(ll)
            tmp_out_ = self.Conv3(tmp_out)
            return x + ll + tmp_out_

        else:
            return tmp_out


class HourGlassBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(HourGlassBlock, self).__init__()

        self.ConvBlock1_1 = ConvBlock(dim_in, dim_out)
        self.ConvBlock1_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock3_1 = ConvBlock(dim_out, dim_out)
```
|
212 |
+
self.ConvBlock3_2 = ConvBlock(dim_out, dim_out)
|
213 |
+
self.ConvBlock4_1 = ConvBlock(dim_out, dim_out)
|
214 |
+
self.ConvBlock4_2 = ConvBlock(dim_out, dim_out)
|
215 |
+
|
216 |
+
self.ConvBlock5 = ConvBlock(dim_out, dim_out)
|
217 |
+
|
218 |
+
self.ConvBlock6 = ConvBlock(dim_out, dim_out)
|
219 |
+
self.ConvBlock7 = ConvBlock(dim_out, dim_out)
|
220 |
+
self.ConvBlock8 = ConvBlock(dim_out, dim_out)
|
221 |
+
self.ConvBlock9 = ConvBlock(dim_out, dim_out)
|
222 |
+
|
223 |
+
def forward(self, x):
|
224 |
+
skip1 = self.ConvBlock1_1(x)
|
225 |
+
down1 = F.avg_pool2d(x, 2)
|
226 |
+
down1 = self.ConvBlock1_2(down1)
|
227 |
+
|
228 |
+
skip2 = self.ConvBlock2_1(down1)
|
229 |
+
down2 = F.avg_pool2d(down1, 2)
|
230 |
+
down2 = self.ConvBlock2_2(down2)
|
231 |
+
|
232 |
+
skip3 = self.ConvBlock3_1(down2)
|
233 |
+
down3 = F.avg_pool2d(down2, 2)
|
234 |
+
down3 = self.ConvBlock3_2(down3)
|
235 |
+
|
236 |
+
skip4 = self.ConvBlock4_1(down3)
|
237 |
+
down4 = F.avg_pool2d(down3, 2)
|
238 |
+
down4 = self.ConvBlock4_2(down4)
|
239 |
+
|
240 |
+
center = self.ConvBlock5(down4)
|
241 |
+
|
242 |
+
up4 = self.ConvBlock6(center)
|
243 |
+
up4 = F.upsample(up4, scale_factor=2)
|
244 |
+
up4 = skip4 + up4
|
245 |
+
|
246 |
+
up3 = self.ConvBlock7(up4)
|
247 |
+
up3 = F.upsample(up3, scale_factor=2)
|
248 |
+
up3 = skip3 + up3
|
249 |
+
|
250 |
+
up2 = self.ConvBlock8(up3)
|
251 |
+
up2 = F.upsample(up2, scale_factor=2)
|
252 |
+
up2 = skip2 + up2
|
253 |
+
|
254 |
+
up1 = self.ConvBlock9(up2)
|
255 |
+
up1 = F.upsample(up1, scale_factor=2)
|
256 |
+
up1 = skip1 + up1
|
257 |
+
|
258 |
+
return up1
|
259 |
+
|
260 |
+
|
261 |
+
class ResnetBlock(nn.Module):
|
262 |
+
def __init__(self, dim, use_bias=False):
|
263 |
+
super(ResnetBlock, self).__init__()
|
264 |
+
conv_block = []
|
265 |
+
conv_block += [nn.ReflectionPad2d(1),
|
266 |
+
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
|
267 |
+
nn.InstanceNorm2d(dim),
|
268 |
+
nn.ReLU(True)]
|
269 |
+
|
270 |
+
conv_block += [nn.ReflectionPad2d(1),
|
271 |
+
nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
|
272 |
+
nn.InstanceNorm2d(dim)]
|
273 |
+
|
274 |
+
self.conv_block = nn.Sequential(*conv_block)
|
275 |
+
|
276 |
+
def forward(self, x):
|
277 |
+
out = x + self.conv_block(x)
|
278 |
+
return out
|
279 |
+
|
280 |
+
|
281 |
+
class ResnetSoftAdaLINBlock(nn.Module):
|
282 |
+
def __init__(self, dim, use_bias=False):
|
283 |
+
super(ResnetSoftAdaLINBlock, self).__init__()
|
284 |
+
self.pad1 = nn.ReflectionPad2d(1)
|
285 |
+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
|
286 |
+
self.norm1 = SoftAdaLIN(dim)
|
287 |
+
self.relu1 = nn.ReLU(True)
|
288 |
+
|
289 |
+
self.pad2 = nn.ReflectionPad2d(1)
|
290 |
+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
|
291 |
+
self.norm2 = SoftAdaLIN(dim)
|
292 |
+
|
293 |
+
def forward(self, x, content_features, style_features):
|
294 |
+
out = self.pad1(x)
|
295 |
+
out = self.conv1(out)
|
296 |
+
out = self.norm1(out, content_features, style_features)
|
297 |
+
out = self.relu1(out)
|
298 |
+
|
299 |
+
out = self.pad2(out)
|
300 |
+
out = self.conv2(out)
|
301 |
+
out = self.norm2(out, content_features, style_features)
|
302 |
+
return out + x
|
303 |
+
|
304 |
+
|
305 |
+
class ResnetAdaLINBlock(nn.Module):
|
306 |
+
def __init__(self, dim, use_bias=False):
|
307 |
+
super(ResnetAdaLINBlock, self).__init__()
|
308 |
+
self.pad1 = nn.ReflectionPad2d(1)
|
309 |
+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
|
310 |
+
self.norm1 = adaLIN(dim)
|
311 |
+
self.relu1 = nn.ReLU(True)
|
312 |
+
|
313 |
+
self.pad2 = nn.ReflectionPad2d(1)
|
314 |
+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
|
315 |
+
self.norm2 = adaLIN(dim)
|
316 |
+
|
317 |
+
def forward(self, x, gamma, beta):
|
318 |
+
out = self.pad1(x)
|
319 |
+
out = self.conv1(out)
|
320 |
+
out = self.norm1(out, gamma, beta)
|
321 |
+
out = self.relu1(out)
|
322 |
+
out = self.pad2(out)
|
323 |
+
out = self.conv2(out)
|
324 |
+
out = self.norm2(out, gamma, beta)
|
325 |
+
|
326 |
+
return out + x
|
327 |
+
|
328 |
+
|
329 |
+
class SoftAdaLIN(nn.Module):
|
330 |
+
def __init__(self, num_features, eps=1e-5):
|
331 |
+
super(SoftAdaLIN, self).__init__()
|
332 |
+
self.norm = adaLIN(num_features, eps)
|
333 |
+
|
334 |
+
self.w_gamma = Parameter(torch.zeros(1, num_features))
|
335 |
+
self.w_beta = Parameter(torch.zeros(1, num_features))
|
336 |
+
|
337 |
+
self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features),
|
338 |
+
nn.ReLU(True),
|
339 |
+
nn.Linear(num_features, num_features))
|
340 |
+
self.c_beta = nn.Sequential(nn.Linear(num_features, num_features),
|
341 |
+
nn.ReLU(True),
|
342 |
+
nn.Linear(num_features, num_features))
|
343 |
+
self.s_gamma = nn.Linear(num_features, num_features)
|
344 |
+
self.s_beta = nn.Linear(num_features, num_features)
|
345 |
+
|
346 |
+
def forward(self, x, content_features, style_features):
|
347 |
+
content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
|
348 |
+
style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)
|
349 |
+
|
350 |
+
w_gamma, w_beta = self.w_gamma.expand(x.shape[0], -1), self.w_beta.expand(x.shape[0], -1)
|
351 |
+
soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
|
352 |
+
soft_beta = (1. - w_beta) * style_beta + w_beta * content_beta
|
353 |
+
|
354 |
+
out = self.norm(x, soft_gamma, soft_beta)
|
355 |
+
return out
|
356 |
+
|
357 |
+
|
358 |
+
class adaLIN(nn.Module):
|
359 |
+
def __init__(self, num_features, eps=1e-5):
|
360 |
+
super(adaLIN, self).__init__()
|
361 |
+
self.eps = eps
|
362 |
+
self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
|
363 |
+
self.rho.data.fill_(0.9)
|
364 |
+
|
365 |
+
def forward(self, input, gamma, beta):
|
366 |
+
in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
|
367 |
+
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
|
368 |
+
ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
|
369 |
+
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
|
370 |
+
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
|
371 |
+
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
372 |
+
|
373 |
+
return out
|
374 |
+
|
375 |
+
|
376 |
+
class LIN(nn.Module):
|
377 |
+
def __init__(self, num_features, eps=1e-5):
|
378 |
+
super(LIN, self).__init__()
|
379 |
+
self.eps = eps
|
380 |
+
self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
|
381 |
+
self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
|
382 |
+
self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
|
383 |
+
self.rho.data.fill_(0.0)
|
384 |
+
self.gamma.data.fill_(1.0)
|
385 |
+
self.beta.data.fill_(0.0)
|
386 |
+
|
387 |
+
def forward(self, input):
|
388 |
+
in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
|
389 |
+
out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
|
390 |
+
ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
|
391 |
+
out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
|
392 |
+
out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
|
393 |
+
out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
|
394 |
+
|
395 |
+
return out
|
396 |
+
|
397 |
+
|
398 |
+
class Discriminator(nn.Module):
|
399 |
+
def __init__(self, input_nc, ndf=64, n_layers=5):
|
400 |
+
super(Discriminator, self).__init__()
|
401 |
+
model = [nn.ReflectionPad2d(1),
|
402 |
+
nn.utils.spectral_norm(
|
403 |
+
nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
|
404 |
+
nn.LeakyReLU(0.2, True)]
|
405 |
+
|
406 |
+
for i in range(1, n_layers - 2):
|
407 |
+
mult = 2 ** (i - 1)
|
408 |
+
model += [nn.ReflectionPad2d(1),
|
409 |
+
nn.utils.spectral_norm(
|
410 |
+
nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
|
411 |
+
nn.LeakyReLU(0.2, True)]
|
412 |
+
|
413 |
+
mult = 2 ** (n_layers - 2 - 1)
|
414 |
+
model += [nn.ReflectionPad2d(1),
|
415 |
+
nn.utils.spectral_norm(
|
416 |
+
nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
|
417 |
+
nn.LeakyReLU(0.2, True)]
|
418 |
+
|
419 |
+
# Class Activation Map
|
420 |
+
mult = 2 ** (n_layers - 2)
|
421 |
+
self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
|
422 |
+
self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
|
423 |
+
self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
|
424 |
+
self.leaky_relu = nn.LeakyReLU(0.2, True)
|
425 |
+
|
426 |
+
self.pad = nn.ReflectionPad2d(1)
|
427 |
+
self.conv = nn.utils.spectral_norm(
|
428 |
+
nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))
|
429 |
+
|
430 |
+
self.model = nn.Sequential(*model)
|
431 |
+
|
432 |
+
def forward(self, input):
|
433 |
+
x = self.model(input)
|
434 |
+
|
435 |
+
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
|
436 |
+
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
|
437 |
+
gap_weight = list(self.gap_fc.parameters())[0]
|
438 |
+
gap = x * gap_weight.unsqueeze(2).unsqueeze(3)
|
439 |
+
|
440 |
+
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
|
441 |
+
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
|
442 |
+
gmp_weight = list(self.gmp_fc.parameters())[0]
|
443 |
+
gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)
|
444 |
+
|
445 |
+
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
|
446 |
+
x = torch.cat([gap, gmp], 1)
|
447 |
+
x = self.leaky_relu(self.conv1x1(x))
|
448 |
+
|
449 |
+
heatmap = torch.sum(x, dim=1, keepdim=True)
|
450 |
+
|
451 |
+
x = self.pad(x)
|
452 |
+
out = self.conv(x)
|
453 |
+
|
454 |
+
return out, cam_logit, heatmap
|
455 |
+
|
456 |
+
|
457 |
+
class RhoClipper(object):
|
458 |
+
def __init__(self, min, max):
|
459 |
+
self.clip_min = min
|
460 |
+
self.clip_max = max
|
461 |
+
assert min < max
|
462 |
+
|
463 |
+
def __call__(self, module):
|
464 |
+
if hasattr(module, 'rho'):
|
465 |
+
w = module.rho.data
|
466 |
+
w = w.clamp(self.clip_min, self.clip_max)
|
467 |
+
module.rho.data = w
|
468 |
+
|
469 |
+
|
470 |
+
class WClipper(object):
|
471 |
+
def __init__(self, min, max):
|
472 |
+
self.clip_min = min
|
473 |
+
self.clip_max = max
|
474 |
+
assert min < max
|
475 |
+
|
476 |
+
def __call__(self, module):
|
477 |
+
if hasattr(module, 'w_gamma'):
|
478 |
+
w = module.w_gamma.data
|
479 |
+
w = w.clamp(self.clip_min, self.clip_max)
|
480 |
+
module.w_gamma.data = w
|
481 |
+
|
482 |
+
if hasattr(module, 'w_beta'):
|
483 |
+
w = module.w_beta.data
|
484 |
+
w = w.clamp(self.clip_min, self.clip_max)
|
485 |
+
module.w_beta.data = w
|
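RhoClipper and WClipper above are plain callables intended to be passed to nn.Module.apply, so that every rho, w_gamma and w_beta parameter is clamped back into a fixed range after an optimizer update. The actual training loop lives in UGATIT_sadalin_hourglass.py (not shown here), so the following is only a minimal sketch of the pattern; the [0, 1] bounds mirror the --rho_clipper / --w_clipper defaults in train.py and are otherwise an assumption.

# sketch only, not repo code: clamp normalization weights after each update
import torch
from models.networks import ResnetGenerator, RhoClipper, WClipper

gen = ResnetGenerator(ngf=32, img_size=256, light=True)
opt = torch.optim.Adam(gen.parameters(), lr=1e-4)

rho_clipper = RhoClipper(0, 1)   # clamps adaLIN / LIN rho into [0, 1]
w_clipper = WClipper(0, 1)       # clamps SoftAdaLIN w_gamma / w_beta into [0, 1]

# ... forward pass, loss.backward() ...
opt.step()
gen.apply(rho_clipper)           # nn.Module.apply visits every submodule
gen.apply(w_clipper)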
p2c/models/photo2cartoon_weights.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:542914cb8580cb733c7e914d22cc24ddabbbb207516d74ffc793f2a1b6c3eeb3
+size 15290506
p2c/models/photo2cartoon_weights.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e08c84ea4c62251c6157dbf1d3ef44d2549d6aa8c9ee72ec9e4b3089ce5d5f0f
+size 144306956
p2c/predict.py
ADDED
@@ -0,0 +1,57 @@
+import cog
+import cv2
+import tempfile
+import torch
+import numpy as np
+import os
+from pathlib import Path
+from utils import Preprocess
+from models import ResnetGenerator
+
+
+class Predictor(cog.Predictor):
+    def setup(self):
+        pass
+
+    @cog.input("photo", type=Path, help="portrait photo (size < 1M)")
+    def predict(self, photo):
+        img = cv2.cvtColor(cv2.imread(str(photo)), cv2.COLOR_BGR2RGB)
+        out_path = gen_cartoon(img)
+        return out_path
+
+
+def gen_cartoon(img):
+    pre = Preprocess()
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    net = ResnetGenerator(ngf=32, img_size=256, light=True).to(device)
+
+    assert os.path.exists(
+        './models/photo2cartoon_weights.pt'), "[Step1: load weights] Can not find 'photo2cartoon_weights.pt' in folder 'models!!!'"
+    params = torch.load('./models/photo2cartoon_weights.pt', map_location=device)
+    net.load_state_dict(params['genA2B'])
+
+    # face alignment and segmentation
+    face_rgba = pre.process(img)
+    if face_rgba is None:
+        return None
+
+    face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
+    face = face_rgba[:, :, :3].copy()
+    mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
+    face = (face * mask + (1 - mask) * 255) / 127.5 - 1
+
+    face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
+    face = torch.from_numpy(face).to(device)
+
+    # inference
+    with torch.no_grad():
+        cartoon = net(face)[0][0]
+
+    # post-process
+    cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
+    cartoon = (cartoon + 1) * 127.5
+    cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
+    cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
+    out_path = Path(tempfile.mkdtemp()) / "out.png"
+    cv2.imwrite(str(out_path), cartoon)
+    return out_path
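predict.py exposes the same pipeline to the cog CLI, but gen_cartoon can also be called directly. A minimal usage sketch, assuming it is run from the p2c directory with the .pt weights in models/ and a hypothetical portrait photo.jpg beside it:

# usage sketch only, not repo code
import cv2
from predict import gen_cartoon

img = cv2.cvtColor(cv2.imread('photo.jpg'), cv2.COLOR_BGR2RGB)  # same BGR->RGB step as predict()
out_path = gen_cartoon(img)  # Path to out.png, or None if no face is detected
print(out_path)

Note that gen_cartoon rebuilds the generator and reloads the weights on every call, since setup() is left empty.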
p2c/test.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import cv2
+import torch
+import numpy as np
+from models import ResnetGenerator
+import argparse
+from utils import Preprocess
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--photo_path', type=str, help='input photo path')
+parser.add_argument('--save_path', type=str, help='cartoon save path')
+args = parser.parse_args()
+
+os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+class Photo2Cartoon:
+    def __init__(self):
+        self.pre = Preprocess()
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net = ResnetGenerator(ngf=32, img_size=256, light=True).to(self.device)
+
+        assert os.path.exists('./models/photo2cartoon_weights.pt'), "[Step1: load weights] Can not find 'photo2cartoon_weights.pt' in folder 'models!!!'"
+        params = torch.load('./models/photo2cartoon_weights.pt', map_location=self.device)
+        self.net.load_state_dict(params['genA2B'])
+        print('[Step1: load weights] success!')
+
+    def inference(self, img):
+        # face alignment and segmentation
+        face_rgba = self.pre.process(img)
+        if face_rgba is None:
+            print('[Step2: face detect] can not detect face!!!')
+            return None
+
+        print('[Step2: face detect] success!')
+        face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
+        face = face_rgba[:, :, :3].copy()
+        mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
+        face = (face*mask + (1-mask)*255) / 127.5 - 1
+
+        face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
+        face = torch.from_numpy(face).to(self.device)
+
+        # inference
+        with torch.no_grad():
+            cartoon = self.net(face)[0][0]
+
+        # post-process
+        cartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))
+        cartoon = (cartoon + 1) * 127.5
+        cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
+        cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
+        print('[Step3: photo to cartoon] success!')
+        return cartoon
+
+
+if __name__ == '__main__':
+    img = cv2.cvtColor(cv2.imread(args.photo_path), cv2.COLOR_BGR2RGB)
+    c2p = Photo2Cartoon()
+    cartoon = c2p.inference(img)
+    if cartoon is not None:
+        cv2.imwrite(args.save_path, cartoon)
+        print('Cartoon portrait has been saved successfully!')
p2c/test_onnx.py
ADDED
@@ -0,0 +1,49 @@
+import os
+import cv2
+import numpy as np
+import onnxruntime
+from utils import Preprocess
+
+
+class Photo2Cartoon:
+    def __init__(self):
+        self.pre = Preprocess()
+        curPath = os.path.abspath(os.path.dirname(__file__))
+        # assert os.path.exists('./models/photo2cartoon_weights.onnx'), "[Step1: load weights] Can not find 'photo2cartoon_weights.onnx' in folder 'models!!!'"
+        self.session = onnxruntime.InferenceSession(os.path.join(curPath, 'models/photo2cartoon_weights.onnx'))
+        print('[Step1: load weights] success!')
+
+    def inference(self, in_path):
+        img = cv2.cvtColor(cv2.imread(in_path), cv2.COLOR_BGR2RGB)
+        # face alignment and segmentation
+        face_rgba = self.pre.process(img)
+        if face_rgba is None:
+            print('[Step2: face detect] can not detect face!!!')
+            return None
+
+        print('[Step2: face detect] success!')
+        face_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)
+        face = face_rgba[:, :, :3].copy()
+        mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.
+        face = (face * mask + (1 - mask) * 255) / 127.5 - 1
+
+        face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)
+
+        # inference
+        cartoon = self.session.run(['output'], input_feed={'input': face})
+
+        # post-process
+        cartoon = np.transpose(cartoon[0][0], (1, 2, 0))
+        cartoon = (cartoon + 1) * 127.5
+        cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)
+        #cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)
+
+        print('[Step3: photo to cartoon] success!')
+        return cartoon
+
+
+if __name__ == '__main__':
+    c2p = Photo2Cartoon()
+    cartoon = c2p.inference('')
+    if cartoon is not None:
+        print('Cartoon portrait has been saved successfully!')
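test_onnx.py takes a file path and returns the cartoon as an RGB uint8 array; unlike test.py, the RGB-to-BGR conversion is commented out, so a caller that saves the result with OpenCV has to convert it first. A minimal usage sketch with hypothetical paths, run from the p2c directory:

# usage sketch only, not repo code
import cv2
from test_onnx import Photo2Cartoon

c2p = Photo2Cartoon()
cartoon = c2p.inference('photo.jpg')  # RGB array, or None if no face is detected
if cartoon is not None:
    cv2.imwrite('cartoon.png', cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR))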
p2c/train.py
ADDED
@@ -0,0 +1,84 @@
+from models import UgatitSadalinHourglass
+import argparse
+import shutil
+from utils import *
+
+
+def parse_args():
+    """parsing and configuration"""
+    desc = "photo2cartoon"
+    parser = argparse.ArgumentParser(description=desc)
+    parser.add_argument('--phase', type=str, default='train', help='[train / test]')
+    parser.add_argument('--light', type=str2bool, default=True, help='[U-GAT-IT full version / U-GAT-IT light version]')
+    parser.add_argument('--dataset', type=str, default='photo2cartoon', help='dataset name')
+
+    parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
+    parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
+    parser.add_argument('--print_freq', type=int, default=1000, help='The number of image print freq')
+    parser.add_argument('--save_freq', type=int, default=1000, help='The number of model save freq')
+    parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
+
+    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
+    parser.add_argument('--adv_weight', type=int, default=1, help='Weight for GAN')
+    parser.add_argument('--cycle_weight', type=int, default=50, help='Weight for Cycle')
+    parser.add_argument('--identity_weight', type=int, default=10, help='Weight for Identity')
+    parser.add_argument('--cam_weight', type=int, default=1000, help='Weight for CAM')
+    parser.add_argument('--faceid_weight', type=int, default=1, help='Weight for Face ID')
+
+    parser.add_argument('--ch', type=int, default=32, help='base channel number per layer')
+    parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
+
+    parser.add_argument('--img_size', type=int, default=256, help='The size of image')
+    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
+
+    # parser.add_argument('--device', type=str, default='cuda:0', help='Set gpu mode: [cpu, cuda]')
+    parser.add_argument('--gpu_ids', type=int, default=[0], nargs='+', help='Set [0, 1, 2, 3] for multi-gpu training')
+    parser.add_argument('--benchmark_flag', type=str2bool, default=False)
+    parser.add_argument('--resume', type=str2bool, default=False)
+    parser.add_argument('--rho_clipper', type=float, default=1.0)
+    parser.add_argument('--w_clipper', type=float, default=1.0)
+    parser.add_argument('--pretrained_weights', type=str, default='', help='pretrained weight path')
+
+    args = parser.parse_args()
+    args.result_dir = './experiment/{}-size{}-ch{}-{}-lr{}-adv{}-cyc{}-id{}-identity{}-cam{}'.format(
+        os.path.basename(__file__)[:-3],
+        args.img_size,
+        args.ch,
+        args.light,
+        args.lr,
+        args.adv_weight,
+        args.cycle_weight,
+        args.faceid_weight,
+        args.identity_weight,
+        args.cam_weight)
+
+    return check_args(args)
+
+
+def check_args(args):
+    check_folder(os.path.join(args.result_dir, args.dataset, 'model'))
+    check_folder(os.path.join(args.result_dir, args.dataset, 'img'))
+    check_folder(os.path.join(args.result_dir, args.dataset, 'test'))
+    shutil.copy(__file__, args.result_dir)
+    return args
+
+
+def main():
+    args = parse_args()
+    if args is None:
+        exit()
+
+    gan = UgatitSadalinHourglass(args)
+    gan.build_model()
+
+    if args.phase == 'train':
+        gan.train()
+        print(" [*] Training finished!")
+
+    if args.phase == 'test':
+        gan.test()
+        print(" [*] Test finished!")
+
+
+if __name__ == '__main__':
+    main()
p2c/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .preprocess import Preprocess
+from .utils import *
p2c/utils/face_detect.py
ADDED
@@ -0,0 +1,80 @@
+import cv2
+import math
+import numpy as np
+import face_alignment
+
+
+class FaceDetect:
+    def __init__(self, device, detector):
+        # landmarks will be detected by face_alignment library. Set device = 'cuda' if use GPU.
+        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device, face_detector=detector)
+
+    def align(self, image):
+        landmarks = self.__get_max_face_landmarks(image)
+
+        if landmarks is None:
+            return None
+
+        else:
+            return self.__rotate(image, landmarks)
+
+    def __get_max_face_landmarks(self, image):
+        preds = self.fa.get_landmarks(image)
+        if preds is None:
+            return None
+
+        elif len(preds) == 1:
+            return preds[0]
+
+        else:
+            # find max face
+            areas = []
+            for pred in preds:
+                landmarks_top = np.min(pred[:, 1])
+                landmarks_bottom = np.max(pred[:, 1])
+                landmarks_left = np.min(pred[:, 0])
+                landmarks_right = np.max(pred[:, 0])
+                areas.append((landmarks_bottom - landmarks_top) * (landmarks_right - landmarks_left))
+            max_face_index = np.argmax(areas)
+            return preds[max_face_index]
+
+    @staticmethod
+    def __rotate(image, landmarks):
+        # rotation angle
+        left_eye_corner = landmarks[36]
+        right_eye_corner = landmarks[45]
+        radian = np.arctan((left_eye_corner[1] - right_eye_corner[1]) / (left_eye_corner[0] - right_eye_corner[0]))
+
+        # image size after rotating
+        height, width, _ = image.shape
+        cos = math.cos(radian)
+        sin = math.sin(radian)
+        new_w = int(width * abs(cos) + height * abs(sin))
+        new_h = int(width * abs(sin) + height * abs(cos))
+
+        # translation
+        Tx = new_w // 2 - width // 2
+        Ty = new_h // 2 - height // 2
+
+        # affine matrix
+        M = np.array([[cos, sin, (1 - cos) * width / 2. - sin * height / 2. + Tx],
+                      [-sin, cos, sin * width / 2. + (1 - cos) * height / 2. + Ty]])
+
+        image_rotate = cv2.warpAffine(image, M, (new_w, new_h), borderValue=(255, 255, 255))
+
+        landmarks = np.concatenate([landmarks, np.ones((landmarks.shape[0], 1))], axis=1)
+        landmarks_rotate = np.dot(M, landmarks.T).T
+        return image_rotate, landmarks_rotate
+
+
+if __name__ == '__main__':
+    img = cv2.cvtColor(cv2.imread('3989161_1.jpg'), cv2.COLOR_BGR2RGB)
+    fd = FaceDetect(device='cpu', detector='dlib')  # __init__ requires both device and detector ('dlib' or 'sfd')
+    face_info = fd.align(img)
+    if face_info is not None:
+        image_align, landmarks_align = face_info
+
+        for i in range(landmarks_align.shape[0]):
+            cv2.circle(image_align, (int(landmarks_align[i][0]), int(landmarks_align[i][1])), 2, (255, 0, 0), -1)
+
+        cv2.imwrite('image_align.png', cv2.cvtColor(image_align, cv2.COLOR_RGB2BGR))
p2c/utils/face_seg.py
ADDED
@@ -0,0 +1,44 @@
+import os
+import cv2
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+
+curPath = os.path.abspath(os.path.dirname(__file__))
+
+
+class FaceSeg:
+    def __init__(self, model_path=os.path.join(curPath, 'seg_model_384.pb')):
+        config = tf.compat.v1.ConfigProto()
+        config.gpu_options.allow_growth = True
+        self._graph = tf.Graph()
+        self._sess = tf.compat.v1.Session(config=config, graph=self._graph)
+
+        self.pb_file_path = model_path
+        self._restore_from_pb()
+        self.input_op = self._sess.graph.get_tensor_by_name('input_1:0')
+        self.output_op = self._sess.graph.get_tensor_by_name('sigmoid/Sigmoid:0')
+
+    def _restore_from_pb(self):
+        with self._sess.as_default():
+            with self._graph.as_default():
+                with gfile.FastGFile(self.pb_file_path, 'rb') as f:
+                    graph_def = tf.compat.v1.GraphDef()
+                    graph_def.ParseFromString(f.read())
+                    tf.import_graph_def(graph_def, name='')
+
+    def input_transform(self, image):
+        image = cv2.resize(image, (384, 384), interpolation=cv2.INTER_AREA)
+        image_input = (image / 255.)[np.newaxis, :, :, :]
+        return image_input
+
+    def output_transform(self, output, shape):
+        output = cv2.resize(output, (shape[1], shape[0]))
+        image_output = (output * 255).astype(np.uint8)
+        return image_output
+
+    def get_mask(self, image):
+        image_input = self.input_transform(image)
+        output = self._sess.run(self.output_op, feed_dict={self.input_op: image_input})[0]
+        return self.output_transform(output, shape=image.shape[:2])
p2c/utils/preprocess.py
ADDED
@@ -0,0 +1,54 @@
+from .face_detect import FaceDetect
+from .face_seg import FaceSeg
+import numpy as np
+
+
+class Preprocess:
+    def __init__(self, device='cpu', detector='dlib'):
+        self.detect = FaceDetect(device, detector)  # device = 'cpu' or 'cuda', detector = 'dlib' or 'sfd'
+        self.segment = FaceSeg()
+
+    def process(self, image):
+        face_info = self.detect.align(image)
+        if face_info is None:
+            return None
+        image_align, landmarks_align = face_info
+
+        face = self.__crop(image_align, landmarks_align)
+        mask = self.segment.get_mask(face)
+        return np.dstack((face, mask))
+
+    @staticmethod
+    def __crop(image, landmarks):
+        landmarks_top = np.min(landmarks[:, 1])
+        landmarks_bottom = np.max(landmarks[:, 1])
+        landmarks_left = np.min(landmarks[:, 0])
+        landmarks_right = np.max(landmarks[:, 0])
+
+        # expand bbox
+        top = int(landmarks_top - 0.8 * (landmarks_bottom - landmarks_top))
+        bottom = int(landmarks_bottom + 0.3 * (landmarks_bottom - landmarks_top))
+        left = int(landmarks_left - 0.3 * (landmarks_right - landmarks_left))
+        right = int(landmarks_right + 0.3 * (landmarks_right - landmarks_left))
+
+        if bottom - top > right - left:
+            left -= ((bottom - top) - (right - left)) // 2
+            right = left + (bottom - top)
+        else:
+            top -= ((right - left) - (bottom - top)) // 2
+            bottom = top + (right - left)
+
+        image_crop = np.ones((bottom - top + 1, right - left + 1, 3), np.uint8) * 255
+
+        h, w = image.shape[:2]
+        left_white = max(0, -left)
+        left = max(0, left)
+        right = min(right, w-1)
+        right_white = left_white + (right-left)
+        top_white = max(0, -top)
+        top = max(0, top)
+        bottom = min(bottom, h-1)
+        bottom_white = top_white + (bottom - top)
+
+        image_crop[top_white:bottom_white+1, left_white:right_white+1] = image[top:bottom+1, left:right+1].copy()
+        return image_crop
p2c/utils/seg_model_384.pb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66a04bc2032b54013d2ae994b34d22518144276f1cbdd2d8cbb1a4a28f50285f
+size 32477258
p2c/utils/utils.py
ADDED
@@ -0,0 +1,94 @@
+import os
+import cv2
+import torch
+import numpy as np
+from scipy import misc
+
+
+def load_test_data(image_path, size=256):
+    img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+    if img is None:
+        return None
+
+    h, w, c = img.shape
+    if img.shape[2] == 4:
+        white = np.ones((h, w, 3), np.uint8) * 255
+        img_rgb = img[:, :, :3].copy()
+        mask = img[:, :, 3].copy()
+        mask = (mask / 255).astype(np.uint8)
+        img = (img_rgb * mask[:, :, np.newaxis]).astype(np.uint8) + white * (1 - mask[:, :, np.newaxis])
+
+    img = cv2.resize(img, (size, size), cv2.INTER_AREA)
+    img = RGB2BGR(img)
+
+    img = np.expand_dims(img, axis=0)
+    img = preprocessing(img)
+    return img
+
+
+def preprocessing(x):
+    x = x/127.5 - 1
+    # -1 ~ 1
+    return x
+
+
+def save_images(images, size, image_path):
+    return imsave(inverse_transform(images), size, image_path)
+
+
+def inverse_transform(images):
+    return (images+1.) / 2
+
+
+def imsave(images, size, path):
+    return misc.imsave(path, merge(images, size))
+
+
+def merge(images, size):
+    h, w = images.shape[1], images.shape[2]
+    img = np.zeros((h * size[0], w * size[1], 3))
+    for idx, image in enumerate(images):
+        i = idx % size[1]
+        j = idx // size[1]
+        img[h*j:h*(j+1), w*i:w*(i+1), :] = image
+
+    return img
+
+
+def check_folder(log_dir):
+    if not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+    return log_dir
+
+
+def str2bool(x):
+    return x.lower() in ('true')
+
+
+def cam(x, size=256):
+    x = x - np.min(x)
+    cam_img = x / np.max(x)
+    cam_img = np.uint8(255 * cam_img)
+    cam_img = cv2.resize(cam_img, (size, size))
+    cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
+    return cam_img / 255.0
+
+
+def imagenet_norm(x):
+    mean = [0.485, 0.456, 0.406]
+    std = [0.299, 0.224, 0.225]
+    mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
+    std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device)
+    return (x - mean) / std
+
+
+def denorm(x):
+    return x * 0.5 + 0.5
+
+
+def tensor2numpy(x):
+    return x.detach().cpu().numpy().transpose(1, 2, 0)
+
+
+def RGB2BGR(x):
+    return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
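One caveat in utils.py: imsave() calls scipy.misc.imsave, which was removed from SciPy well before the scipy==1.7.3 pinned in requirements.txt, so save_images() would raise AttributeError on this stack. A drop-in sketch that writes the merged grid with imageio instead (imageio is an assumed extra dependency, not listed in requirements.txt):

# sketch only, not repo code
import imageio
import numpy as np
from utils import merge, inverse_transform

def save_images_compat(images, size, image_path):
    grid = merge(inverse_transform(images), size)          # float grid in [0, 1]
    imageio.imwrite(image_path, (grid * 255).astype(np.uint8))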
packages.txt
ADDED
@@ -0,0 +1,2 @@
+cmake
+
requirements.txt
ADDED
@@ -0,0 +1,9 @@
+opencv-python-headless==4.5.5.62
+Pillow==9.0.1
+scipy==1.7.3
+tensorflow-gpu==1.14.0
+scikit-image==0.14.5
+onnxruntime
+face-alignment
+dlib
+