jgwill commited on
Commit
6321685
1 Parent(s): 16e9c21
evaluation/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Style transfer quaintitative evaluation using Deception Score
2
+
3
+ ### How to calculate Deception Score:
4
+
5
+ 1. Run `./download_evaluation_data.py` to download the weights for artist classification model.
6
+ 2. Set `results_dir` variable in `eval_deception_score.py:92` to point to the directory with stylized images.
7
+ All images generated by one method must be in one directory.
8
+ Image filenames must be in format `"{content_name}_stylized_{artist_name}.jpg"`, for example: `"Places366_val_00000510_stylized_vincent-van-gogh.jpg"`.
9
+ 3. Run `./run_deception_score_vgg_16_wikiart.sh`
10
+ 4. Read results in the log file in `./logs` directory.
11
+
12
+
13
+ ### How to evaluate your own model:
14
+
15
+ - Download validation sets from MSCOCO ([val2017.zip](http://images.cocodataset.org/zips/val2017.zip)) and Places365 ([val_large.tar](http://data.csail.mit.edu/places/places365/val_large.tar)).
16
+ - To compare with deception score reported in the paper run your stylization model on the content images listed in [eval_paths_700_val.json](evaluation_data/eval_paths_700_val.json).
evaluation/__init__.py ADDED
File without changes
evaluation/check_fc8_labels.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
2
+ #
3
+ # This file is part of Adaptive Style Transfer
4
+ #
5
+ # Adaptive Style Transfer is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # Adaptive Style Transfer is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+
18
+ import pandas as pd
19
+ import h5py
20
+ import numpy as np
21
+
22
+ ARTISTS = ['claude-monet',
23
+ 'paul-cezanne',
24
+ 'el-greco',
25
+ 'paul-gauguin',
26
+ 'samuel-peploe',
27
+ 'vincent-van-gogh',
28
+ 'edvard-munch',
29
+ 'pablo-picasso',
30
+ 'berthe-morisot',
31
+ 'ernst-ludwig-kirchner',
32
+ 'jackson-pollock',
33
+ 'wassily-kandinsky',
34
+ 'nicholas-roerich']
35
+
36
+
37
+ def get_artist_labels_wikiart(artists=ARTISTS):
38
+ """
39
+ Get mapping of artist name to class label
40
+ """
41
+ split_df = pd.read_hdf('evaluation_data/split.hdf5')
42
+
43
+ labels = dict()
44
+
45
+ for artist_id in artists:
46
+ artist_id_in_split = artist_id
47
+ print artist_id
48
+ cur_df = split_df[split_df.index.str.startswith(artist_id_in_split)]
49
+ assert len(cur_df)
50
+ if not np.all(cur_df.index.str.startswith(artist_id_in_split + '_')):
51
+ print cur_df[~cur_df.index.str.startswith(artist_id_in_split + '_')]
52
+ assert False
53
+
54
+ print '===='
55
+ labels[artist_id] = cur_df['label'][0]
56
+ return labels
57
+
58
+
59
+ if __name__ == '__main__':
60
+
61
+ print get_artist_labels_wikiart(ARTISTS)
evaluation/download_evaluation_data.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ from __future__ import print_function
3
+
4
+ import requests
5
+ import os
6
+
7
+ from torchvision.datasets.utils import download_url
8
+
9
+ API_ENDPOINT = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key={}'
10
+
11
+ EVALUATION_DATA_URL = 'https://yadi.sk/d/A2CBqSGuJ0M_XA'
12
+
13
+
14
+ def get_real_direct_link(sharing_link):
15
+ pk_request = requests.get(API_ENDPOINT.format(sharing_link))
16
+
17
+ return pk_request.json()['href']
18
+
19
+
20
+ def unzip(path, target_dir='.'):
21
+ import zipfile
22
+ with zipfile.ZipFile(path, 'r') as zip_ref:
23
+ zip_ref.extractall(target_dir)
24
+
25
+
26
+ def main():
27
+ root = "."
28
+ link = get_real_direct_link(EVALUATION_DATA_URL)
29
+ filename = 'evaluation_data.zip'
30
+ print('Downloadng data (1Gb). This may take a while...')
31
+ download_url(link, root, filename, None)
32
+ print('Unzipping...')
33
+ unzip(os.path.join(root, filename), target_dir='.')
34
+ print('Done.')
35
+
36
+
37
+ if __name__ == '__main__':
38
+ main()
39
+
evaluation/eval_deception_score.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
2
+ #
3
+ # This file is part of Adaptive Style Transfer
4
+ #
5
+ # Adaptive Style Transfer is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # Adaptive Style Transfer is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ from pprint import pformat
22
+ import glob
23
+ import numpy as np
24
+ import pandas as pd
25
+ import re
26
+
27
+ from feature_extractor.feature_extractor import SlimFeatureExtractor
28
+ from logger import Logger
29
+ from check_fc8_labels import get_artist_labels_wikiart
30
+
31
+
32
+ def parse_one_or_list(str_value):
33
+ if str_value is not None:
34
+ if str_value.lower() == 'none':
35
+ str_value = None
36
+ elif ',' in str_value:
37
+ str_value = str_value.split(',')
38
+ return str_value
39
+
40
+
41
+ def parse_list(str_value):
42
+ if ',' in str_value:
43
+ str_value = str_value.split(',')
44
+ else:
45
+ str_value = [str_value]
46
+ return str_value
47
+
48
+
49
+ def parse_none(str_value):
50
+ if str_value is not None:
51
+ if str_value.lower() == 'none' or str_value == "":
52
+ str_value = None
53
+ return str_value
54
+
55
+
56
+ def parse_args(argv):
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument('-net', '--net', help='network type',
59
+ choices=['vgg_16', 'vgg_16_multihead'], default='vgg_16')
60
+ parser.add_argument('-log', '--log-path', help='log path', type=str,
61
+ default='/tmp/res.txt'
62
+ )
63
+ parser.add_argument('-s', '--snapshot_path', type=str,
64
+ default='vgg_16.ckpt')
65
+ parser.add_argument('-b', '--batch-size', type=int, default=64)
66
+ parser.add_argument('--method', type=str, default='ours')
67
+ parser.add_argument('--num_classes', type=int, default=624)
68
+ parser.add_argument('--dataset', type=str, default='wikiart', choices=['wikiart'])
69
+ args = parser.parse_args(argv)
70
+ args = vars(args)
71
+ return args
72
+
73
+
74
+ def create_slim_extractor(cli_params):
75
+ extractor_class = SlimFeatureExtractor
76
+ extractor_ = extractor_class(cli_params['net'], cli_params['snapshot_path'],
77
+ should_restore_classifier=True,
78
+ gpu_memory_fraction=0.95,
79
+ vgg_16_heads=None if cli_params['net'] != 'vgg_16_multihead' else {'artist_id': cli_params['num_classes']})
80
+ return extractor_
81
+
82
+
83
+ classification_layer = {
84
+ 'vgg_16': 'vgg_16/fc8',
85
+ 'vgg_16_multihead': 'vgg_16/fc8_artist_id'
86
+ }
87
+
88
+
89
+ def run(extractor, classification_layer, images_df, batch_size=64, logger=Logger()):
90
+ images_df = images_df.copy()
91
+ if len(images_df) == 0:
92
+ print 'No images found!'
93
+ return -1, 0, 0
94
+ probs = extractor.extract(images_df['image_path'].values, [classification_layer],
95
+ verbose=1, batch_size=batch_size)
96
+ images_df['predicted_class'] = np.argmax(probs, axis=1).tolist()
97
+ is_correct = images_df['label'] == images_df['predicted_class']
98
+ accuracy = float(is_correct.sum()) / len(images_df)
99
+
100
+ logger.log('Num images: {}'.format(len(images_df)))
101
+ logger.log('Correctly classified: {}/{}'.format(is_correct.sum(), len(images_df)))
102
+ logger.log('Accuracy: {:.5f}'.format(accuracy))
103
+ logger.log('\n===')
104
+ return accuracy, is_correct.sum(), len(images_df)
105
+
106
+
107
+ # image filenames must be in format "{content_name}_stylized_{artist_name}.jpg"
108
+ # uncomment methods which you want to evaluate and set the paths to the folders with the stylized images
109
+ results_dir = {
110
+ 'ours': 'path/to/our/stylizations',
111
+ # 'gatys': 'path/to/gatys_stylizations',
112
+ # 'cyclegan': '',
113
+ # 'adain': '',
114
+ # 'johnson': '',
115
+ # 'wct': '',
116
+ # 'real_wiki_test': os.path.expanduser('~/workspace/wikiart/images_square_227x227') # uncomment to test on real images from wikiart test set
117
+ }
118
+
119
+
120
+ style_2_image_name = {u'berthe-morisot': u'Morisot-1886-the-lesson-in-the-garden',
121
+ u'claude-monet': u'monet-1914-water-lilies-37.jpg!HD',
122
+ u'edvard-munch': u'Munch-the-scream-1893',
123
+ u'el-greco': u'el-greco-the-resurrection-1595.jpg!HD',
124
+ u'ernst-ludwig-kirchner': u'Kirchner-1913-street-berlin.jpg!HD',
125
+ u'jackson-pollock': u'Pollock-number-one-moma-November-31-1950-1950',
126
+ u'nicholas-roerich': u'nicholas-roerich_mongolia-campaign-of-genghis-khan',
127
+ u'pablo-picasso': u'weeping-woman-1937',
128
+ u'paul-cezanne': u'still-life-with-apples-1894.jpg!HD',
129
+ u'paul-gauguin': u'Gauguin-the-seed-of-the-areoi-1892',
130
+ u'samuel-peploe': u'peploe-ile-de-brehat-1911-1',
131
+ u'vincent-van-gogh': u'vincent-van-gogh_road-with-cypresses-1890',
132
+ u'wassily-kandinsky': u'Kandinsky-improvisation-28-second-version-1912'}
133
+
134
+
135
+ artist_2_label_wikiart = get_artist_labels_wikiart()
136
+
137
+
138
+ def get_images_df(dataset, method, artist_slug):
139
+ images_dir = results_dir[method]
140
+ paths = glob.glob(os.path.join(images_dir, '*.jpg')) + glob.glob(os.path.join(images_dir, '*.png'))
141
+ # print paths
142
+ assert len(paths) or method.startswith('real')
143
+
144
+ if not method.startswith('real'):
145
+ cur_style_paths = [x for x in paths if re.match('.*_stylized_({}|{}).(jpg|png)'.format(artist_slug, style_2_image_name[artist_slug]), os.path.basename(x)) is not None]
146
+ elif method == 'real_wiki_test':
147
+ # use only images from the test set
148
+ split_df = pd.read_hdf(os.path.expanduser('evaluation_data/split.hdf5'))
149
+ split_df['image_id'] = split_df.index
150
+ df = split_df[split_df['split'] == 'test']
151
+ df['artist_id'] = df['image_id'].apply(lambda x: x.split('_', 1)[0])
152
+ df['image_path'] = df['image_id'].apply(lambda x: os.path.join(results_dir['real_wiki_test'], x + '.png'))
153
+ cur_style_paths = df.loc[df['artist_id'] == artist_slug, 'image_path'].values
154
+
155
+ df = pd.DataFrame(index=[os.path.basename(x).split('_stylized_', 1)[0].rstrip('.') for x in
156
+ cur_style_paths], data={'image_path': cur_style_paths, 'artist': artist_slug})
157
+
158
+ df['label'] = artist_2_label_wikiart[artist_slug]
159
+ return df
160
+
161
+
162
+ def sprint_stats(stats):
163
+ msg = ''
164
+ msg += 'artist\t accuracy\t is_correct\t total\n'
165
+ for key in sorted(stats.keys()):
166
+ msg += key + '\t {:.5f}\t {}\t \t{}\n'.format(*stats[key])
167
+ return msg
168
+
169
+
170
+ if __name__ == '__main__':
171
+ import sys
172
+
173
+ args = parse_args(sys.argv[1:])
174
+
175
+ if not os.path.exists(os.path.dirname(args['log_path'])):
176
+ os.makedirs(os.path.dirname(args['log_path']))
177
+ logger = Logger(args['log_path'])
178
+ print 'Snapshot: {}'.format(args['snapshot_path'])
179
+ extractor = create_slim_extractor(args)
180
+ classification_layer = classification_layer[args['net']]
181
+
182
+ stats = dict()
183
+ assert artist_2_label_wikiart is not None
184
+ for artist in artist_2_label_wikiart.keys():
185
+ print('Method:', args['method'])
186
+ logger.log('Artist: {}'.format(artist))
187
+ images_df = get_images_df(dataset=args['dataset'], method=args['method'], artist_slug=artist)
188
+ acc, num_is_correct, num_total = run(extractor, classification_layer, images_df,
189
+ batch_size=args['batch_size'], logger=logger)
190
+ stats[artist] = (acc, num_is_correct, num_total)
191
+
192
+ logger.log('{}'.format(pformat(args)))
193
+ print 'Images dir:', results_dir[args['method']]
194
+ logger.log('===\n\n')
195
+ logger.log(args['method'])
196
+ logger.log('{}'.format(sprint_stats(stats)))
evaluation/logger.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
2
+ #
3
+ # This file is part of Adaptive Style Transfer
4
+ #
5
+ # Adaptive Style Transfer is free software: you can redistribute it and/or modify
6
+ # it under the terms of the GNU General Public License as published by
7
+ # the Free Software Foundation, either version 3 of the License, or
8
+ # (at your option) any later version.
9
+ #
10
+ # Adaptive Style Transfer is distributed in the hope that it will be useful,
11
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ # GNU General Public License for more details.
14
+ #
15
+ # You should have received a copy of the GNU General Public License
16
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+
18
+ import sys
19
+ import os
20
+ import subprocess
21
+
22
+
23
+ class Logger(object):
24
+ def __init__(self, filepath=None, mode='w'):
25
+ self.file = None
26
+ self.filepath = filepath
27
+ if filepath is not None:
28
+ self.file = open(filepath, mode=mode, buffering=0)
29
+
30
+ def __enter__(self):
31
+ return self
32
+
33
+ def log(self, msg, should_print=True):
34
+ if should_print:
35
+ print '[LOG] {}'.format(msg)
36
+ if self.file is not None:
37
+ self.file.write('{}\n'.format(msg))
38
+
39
+ def write(self, msg):
40
+ sys.__stdout__.write(msg)
41
+ if self.file is not None:
42
+ self.file.write(msg)
43
+ self.file.flush()
44
+
45
+ def close(self):
46
+ if self.file is not None:
47
+ self.file.close()
48
+
49
+ def __exit__(self, exc_type, exc_val, exc_tb):
50
+ self.close()
51
+
52
+
53
+ def log(logger, msg, should_print=True):
54
+ if logger:
55
+ logger.log(msg, should_print)
56
+ else:
57
+ if should_print:
58
+ print msg
59
+
60
+
61
+ class Tee:
62
+ def __init__(self, log_path):
63
+ self.prev_stdout_descriptor = os.dup(sys.stdout.fileno())
64
+ self.prev_stderr_descriptor = os.dup(sys.stderr.fileno())
65
+
66
+ tee = subprocess.Popen(['tee', log_path], stdin=subprocess.PIPE)
67
+ os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
68
+ os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
69
+
70
+ def close(self):
71
+ os.dup2(self.prev_stdout_descriptor, sys.stdout.fileno())
72
+ os.close(self.prev_stdout_descriptor)
73
+ os.dup2(self.prev_stderr_descriptor, sys.stderr.fileno())
74
+ os.close(self.prev_stderr_descriptor)
75
+
76
+ def __enter__(self):
77
+ return self
78
+
79
+ def __exit__(self, exc_type, exc_val, exc_tb):
80
+ self.close()
evaluation/run_deception_score_vgg_16_wikiart.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # Copyright (C) 2018 Artsiom Sanakoyeu and Dmytro Kotovenko
4
+ #
5
+ # This file is part of Adaptive Style Transfer
6
+ #
7
+ # Adaptive Style Transfer is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU General Public License as published by
9
+ # the Free Software Foundation, either version 3 of the License, or
10
+ # (at your option) any later version.
11
+ #
12
+ # Adaptive Style Transfer is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU General Public License
18
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
19
+
20
+ set -e
21
+ LOG_DIR=logs
22
+ mkdir -p ${LOG_DIR}
23
+
24
+ NET=vgg_16_multihead
25
+
26
+ METHODS=( "ours" )
27
+ #METHODS=( "ours" "real_wiki_test" )
28
+
29
+ for method in ${METHODS[@]}
30
+ do
31
+ echo $method
32
+ python eval_deception_score.py \
33
+ -net=${NET} \
34
+ -s="evaluation_data/model.ckpt-790000" \
35
+ -log=${LOG_DIR}/deception_score_${method}.txt \
36
+ --method=$method \
37
+ --num_classes=624 \
38
+ --dataset="wikiart"
39
+ done