elias3446 committed on
Commit
0f888c1
•
1 Parent(s): a087ce1

Upload 9 files

utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (119 Bytes)
utils/__pycache__/language_utils.cpython-38.pyc ADDED
Binary file (5.73 kB)
utils/__pycache__/options.cpython-38.pyc ADDED
Binary file (3.94 kB)
utils/__pycache__/util.cpython-38.pyc ADDED
Binary file (3.81 kB)
utils/language_utils.py ADDED
@@ -0,0 +1,315 @@
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+
+ # predefined shape text
+ upper_length_text = [
+     'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
+     'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
+     'with short sleeves', 'medium-sleeve', 'medium sleeves',
+     'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
+     'long sleeves', 'with long sleeves'
+ ]
+ upper_length_attr = {
+     'sleeveless': 0,
+     'without sleeves': 0,
+     'sleeves have been cut off': 0,
+     'tank top': 0,
+     'tank shirt': 0,
+     'muscle shirt': 0,
+     'short-sleeve': 1,
+     'with short sleeves': 1,
+     'short sleeves': 1,
+     'medium-sleeve': 2,
+     'with medium sleeves': 2,
+     'medium sleeves': 2,
+     'sleeves reach elbow': 2,
+     'long-sleeve': 3,
+     'long sleeves': 3,
+     'with long sleeves': 3
+ }
+ lower_length_text = [
+     'three-point', 'medium', 'short', 'covering knee', 'cropped',
+     'three-quarter', 'long', 'slack', 'of long length'
+ ]
+ lower_length_attr = {
+     'three-point': 0,
+     'medium': 1,
+     'covering knee': 1,
+     'short': 1,
+     'cropped': 2,
+     'three-quarter': 2,
+     'long': 3,
+     'slack': 3,
+     'of long length': 3
+ }
+ socks_length_text = [
+     'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
+ ]
+ socks_length_attr = {
+     'socks': 0,
+     'stocking': 1,
+     'pantyhose': 1,
+     'leggings': 1,
+     'sheer hosiery': 1
+ }
+ hat_text = ['hat', 'cap', 'chapeau']
+ eyeglasses_text = ['sunglasses']
+ belt_text = ['belt', 'with a dress tied around the waist']
+ outer_shape_text = [
+     'with outer clothing open', 'with outer clothing unzipped',
+     'covering inner clothes', 'with outer clothing zipped'
+ ]
+ outer_shape_attr = {
+     'with outer clothing open': 0,
+     'with outer clothing unzipped': 0,
+     'covering inner clothes': 1,
+     'with outer clothing zipped': 1
+ }
+
+ upper_types = [
+     'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
+ ]
+ outer_types = [
+     'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
+     'duffle', 'cardigan'
+ ]
+ skirt_types = ['skirt']
+ dress_types = ['dress']
+ pant_types = ['jeans', 'pants', 'trousers']
+ rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
+
+ attr_names_list = [
+     'gender', 'hair length', '0 upper clothing length',
+     '1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
+     '6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
+     '9 skirt', '10 dress', '11 pants', '12 rompers'
+ ]
+
+
+ def generate_shape_attributes(user_shape_texts):
+     model = SentenceTransformer('all-MiniLM-L6-v2')
+     parsed_texts = user_shape_texts.split(',')
+
+     text_num = len(parsed_texts)
+
+     # human_attr: [gender, hair length]
+     human_attr = [0, 0]
+     # attr: defaults for the 13 clothing attributes in attr_names_list[2:]
+     attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
+
+     changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+     for text_id, text in enumerate(parsed_texts):
+         user_embeddings = model.encode(text)
+         # note: 'man' also matches 'woman'; the branch below overrides it
+         if ('man' in text) and (text_id == 0):
+             human_attr[0] = 0
+             human_attr[1] = 0
+
+         if ('woman' in text or 'lady' in text) and (text_id == 0):
+             human_attr[0] = 1
+             human_attr[1] = 2
+
+         if (not changed[0]) and (text_id == 1):
+             # upper length
+             predefined_embeddings = model.encode(upper_length_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             attr[0] = upper_length_attr[upper_length_text[arg_idx]]
+             changed[0] = 1
+
+         if (not changed[1]) and ((text_num == 2 and text_id == 1) or
+                                  (text_num > 2 and text_id == 2)):
+             # lower length
+             predefined_embeddings = model.encode(lower_length_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             attr[1] = lower_length_attr[lower_length_text[arg_idx]]
+             changed[1] = 1
+
+         if (not changed[2]) and (text_id > 2):
+             # socks length
+             predefined_embeddings = model.encode(socks_length_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             if similarities[0][arg_idx] > 0.7:
+                 attr[2] = arg_idx + 1
+                 changed[2] = 1
+
+         if (not changed[3]) and (text_id > 2):
+             # hat
+             predefined_embeddings = model.encode(hat_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             if similarities[0][0] > 0.7:
+                 attr[3] = 1
+                 changed[3] = 1
+
+         if (not changed[4]) and (text_id > 2):
+             # glasses
+             predefined_embeddings = model.encode(eyeglasses_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             if similarities[0][arg_idx] > 0.7:
+                 attr[4] = arg_idx + 1
+                 changed[4] = 1
+
+         if (not changed[5]) and (text_id > 2):
+             # belt
+             predefined_embeddings = model.encode(belt_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             if similarities[0][arg_idx] > 0.7:
+                 attr[5] = arg_idx + 1
+                 changed[5] = 1
+
+         if (not changed[6]) and (text_id == 3):
+             # outer coverage
+             predefined_embeddings = model.encode(outer_shape_text)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             if similarities[0][arg_idx] > 0.7:
+                 attr[6] = arg_idx
+                 changed[6] = 1
+
+         if (not changed[10]) and (text_num == 2 and text_id == 1):
+             # dress_types
+             predefined_embeddings = model.encode(dress_types)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             similarity_skirt = util.dot_score(user_embeddings,
+                                               model.encode(skirt_types))
+             if similarities[0][0] > 0.5 and similarities[0][
+                     0] > similarity_skirt[0][0]:
+                 attr[10] = 1
+                 attr[7] = 0
+                 attr[8] = 0
+                 attr[9] = 0
+                 attr[11] = 0
+                 attr[12] = 0
+
+                 changed[0] = 1
+                 changed[10] = 1
+                 changed[7] = 1
+                 changed[8] = 1
+                 changed[9] = 1
+                 changed[11] = 1
+                 changed[12] = 1
+
+         if (not changed[12]) and (text_num == 2 and text_id == 1):
+             # rompers_types
+             predefined_embeddings = model.encode(rompers_types)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             max_similarity = torch.max(similarities).item()
+             if max_similarity > 0.6:
+                 attr[12] = 1
+                 attr[7] = 0
+                 attr[8] = 0
+                 attr[9] = 0
+                 attr[10] = 0
+                 attr[11] = 0
+
+                 changed[12] = 1
+                 changed[7] = 1
+                 changed[8] = 1
+                 changed[9] = 1
+                 changed[10] = 1
+                 changed[11] = 1
+
+         if (not changed[7]) and (text_num > 2 and text_id == 1):
+             # upper_types
+             predefined_embeddings = model.encode(upper_types)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             max_similarity = torch.max(similarities).item()
+             if max_similarity > 0.6:
+                 attr[7] = 1
+                 changed[7] = 1
+
+         if (not changed[8]) and (text_id == 3):
+             # outer_types
+             predefined_embeddings = model.encode(outer_types)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             arg_idx = torch.argmax(similarities).item()
+             if similarities[0][arg_idx] > 0.7:
+                 # attr[6] (outer opening) is set by the branch above;
+                 # indexing outer_shape_text with an argmax over outer_types
+                 # can go out of range, so only the presence flag is set here.
+                 attr[8] = 1
+                 changed[8] = 1
+
+         if (not changed[9]) and (text_num > 2 and text_id == 2):
+             # skirt_types
+             predefined_embeddings = model.encode(skirt_types)
+             similarity_skirt = util.dot_score(user_embeddings,
+                                               predefined_embeddings)
+             similarity_dress = util.dot_score(user_embeddings,
+                                               model.encode(dress_types))
+             if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
+                     0] > similarity_dress[0][0]:
+                 attr[9] = 1
+                 attr[10] = 0
+                 changed[9] = 1
+                 changed[10] = 1
+
+         if (not changed[11]) and (text_num > 2 and text_id == 2):
+             # pant_types
+             predefined_embeddings = model.encode(pant_types)
+             similarities = util.dot_score(user_embeddings,
+                                           predefined_embeddings)
+             max_similarity = torch.max(similarities).item()
+             if max_similarity > 0.6:
+                 attr[11] = 1
+                 attr[9] = 0
+                 attr[10] = 0
+                 attr[12] = 0
+                 changed[11] = 1
+                 changed[9] = 1
+                 changed[10] = 1
+                 changed[12] = 1
+
+     return human_attr + attr
+
+
+ def generate_texture_attributes(user_text):
+     parsed_texts = user_text.split(',')
+
+     attr = []
+     for text in parsed_texts:
+         # texture ids: 0 denim, 1 floral, 3 stripe, 4 solid color,
+         # 5 plaid/lattice, 17 other/unrecognized
+         if ('pure color' in text) or ('solid color' in text):
+             attr.append(4)
+         elif ('spline' in text) or ('stripe' in text):
+             attr.append(3)
+         elif ('plaid' in text) or ('lattice' in text):
+             attr.append(5)
+         elif 'floral' in text:
+             attr.append(1)
+         elif 'denim' in text:
+             attr.append(0)
+         else:
+             attr.append(17)
+
+     # pad the result to three region entries
+     if len(attr) == 1:
+         attr.append(attr[0])
+         attr.append(17)
+
+     if len(attr) == 2:
+         attr.append(17)
+
+     return attr
+
+
+ if __name__ == "__main__":
+     user_request = input('Enter your request: ')
+     while user_request != '\\q':
+         attr = generate_shape_attributes(user_request)
+         print(attr)
+         for attr_name, attr_value in zip(attr_names_list, attr):
+             print(attr_name, attr_value)
+         user_request = input('Enter your request: ')
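For reference, a minimal usage sketch for the two parsers above, outside the diff (it assumes sentence-transformers can fetch the 'all-MiniLM-L6-v2' weights; the request strings are only illustrative):

    from utils.language_utils import (attr_names_list,
                                      generate_shape_attributes,
                                      generate_texture_attributes)

    # segment 0 describes the person, later segments describe garments
    shape_attr = generate_shape_attributes(
        'a lady, short-sleeve T-shirt, long pants')
    for name, value in zip(attr_names_list, shape_attr):
        print(name, value)

    # texture ids for up to three regions; missing regions default to 17
    print(generate_texture_attributes('pure color, denim'))  # [4, 0, 17]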
utils/logger.py ADDED
@@ -0,0 +1,112 @@
+ import datetime
+ import logging
+ import time
+
+
+ class MessageLogger():
+     """Message logger for printing.
+
+     Args:
+         opt (dict): Config. It contains the following keys:
+             name (str): Experiment name.
+             print_freq (int): Logger printing interval.
+             max_iters (int): Total number of iterations.
+             use_tb_logger (bool): Whether to use the tensorboard logger.
+         start_iter (int): Start iteration. Default: 1.
+         tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+     """
+
+     def __init__(self, opt, start_iter=1, tb_logger=None):
+         self.exp_name = opt['name']
+         self.interval = opt['print_freq']
+         self.start_iter = start_iter
+         self.max_iters = opt['max_iters']
+         self.use_tb_logger = opt['use_tb_logger']
+         self.tb_logger = tb_logger
+         self.start_time = time.time()
+         self.logger = get_root_logger()
+
+     def __call__(self, log_vars):
+         """Format logging message.
+
+         Args:
+             log_vars (dict): It contains the following keys:
+                 epoch (int): Epoch number.
+                 iter (int): Current iteration.
+                 lrs (list): List of learning rates.
+                 time (float): Iteration time.
+                 data_time (float): Data loading time for each iteration.
+         """
+         # epoch, iter, learning rates
+         epoch = log_vars.pop('epoch')
+         current_iter = log_vars.pop('iter')
+         lrs = log_vars.pop('lrs')
+
+         message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
+                    f'iter:{current_iter:8,d}, lr:(')
+         for v in lrs:
+             message += f'{v:.3e},'
+         message += ')] '
+
+         # time and estimated time
+         if 'time' in log_vars.keys():
+             iter_time = log_vars.pop('time')
+             data_time = log_vars.pop('data_time')
+
+             total_time = time.time() - self.start_time
+             time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+             eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+             eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+             message += f'[eta: {eta_str}, '
+             message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
+
+         # other items, especially losses
+         for k, v in log_vars.items():
+             message += f'{k}: {v:.4e} '
+             # tensorboard logger
+             if self.use_tb_logger and 'debug' not in self.exp_name:
+                 self.tb_logger.add_scalar(k, v, current_iter)
+
+         self.logger.info(message)
+
+
+ def init_tb_logger(log_dir):
+     from torch.utils.tensorboard import SummaryWriter
+     tb_logger = SummaryWriter(log_dir=log_dir)
+     return tb_logger
+
+
+ def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
+     """Get the root logger.
+
+     The logger will be initialized if it has not been initialized. By default
+     a StreamHandler will be added. If `log_file` is specified, a FileHandler
+     will also be added.
+
+     Args:
+         logger_name (str): Root logger name. Default: 'base'.
+         log_level (int): The root logger level. Note that only the process of
+             rank 0 is affected, while other processes will set the level to
+             "Error" and be silent most of the time.
+         log_file (str | None): The log filename. If specified, a FileHandler
+             will be added to the root logger.
+
+     Returns:
+         logging.Logger: The root logger.
+     """
+     logger = logging.getLogger(logger_name)
+     # if the logger has been initialized, just return it
+     if logger.hasHandlers():
+         return logger
+
+     format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
+     logging.basicConfig(format=format_str, level=log_level)
+
+     if log_file is not None:
+         file_handler = logging.FileHandler(log_file, 'w')
+         file_handler.setFormatter(logging.Formatter(format_str))
+         file_handler.setLevel(log_level)
+         logger.addHandler(file_handler)
+
+     return logger
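For reference, a minimal sketch of how MessageLogger and get_root_logger fit together (the opt dict and log values here are illustrative, not from a real config):

    from utils.logger import MessageLogger, get_root_logger

    logger = get_root_logger(log_file='train.log')
    opt = {'name': 'demo_exp', 'print_freq': 100,
           'max_iters': 1000, 'use_tb_logger': False}
    msg_logger = MessageLogger(opt, start_iter=1)
    # keys beyond epoch/iter/lrs/time/data_time are formatted as losses
    msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
                'time': 0.25, 'data_time': 0.05, 'l_total': 0.1234})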
utils/options.py ADDED
@@ -0,0 +1,129 @@
+ import os
+ import os.path as osp
+ from collections import OrderedDict
+
+ import yaml
+
+
+ def ordered_yaml():
+     """Support OrderedDict for yaml.
+
+     Returns:
+         yaml Loader and Dumper.
+     """
+     try:
+         from yaml import CDumper as Dumper
+         from yaml import CLoader as Loader
+     except ImportError:
+         from yaml import Dumper, Loader
+
+     _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+     def dict_representer(dumper, data):
+         return dumper.represent_dict(data.items())
+
+     def dict_constructor(loader, node):
+         return OrderedDict(loader.construct_pairs(node))
+
+     Dumper.add_representer(OrderedDict, dict_representer)
+     Loader.add_constructor(_mapping_tag, dict_constructor)
+     return Loader, Dumper
+
+
+ def parse(opt_path, is_train=True):
+     """Parse option file.
+
+     Args:
+         opt_path (str): Option file path.
+         is_train (bool): Whether in training mode. Default: True.
+
+     Returns:
+         (dict): Options.
+     """
+     with open(opt_path, mode='r') as f:
+         Loader, _ = ordered_yaml()
+         opt = yaml.load(f, Loader=Loader)
+
+     gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+     if opt.get('set_CUDA_VISIBLE_DEVICES', None):
+         os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+         print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
+     else:
+         print('gpu_list: ', gpu_list, flush=True)
+
+     opt['is_train'] = is_train
+
+     # paths
+     opt['path'] = {}
+     opt['path']['root'] = osp.abspath(
+         osp.join(__file__, osp.pardir, osp.pardir))
+     if is_train:
+         experiments_root = osp.join(opt['path']['root'], 'experiments',
+                                     opt['name'])
+         opt['path']['experiments_root'] = experiments_root
+         opt['path']['models'] = osp.join(experiments_root, 'models')
+         opt['path']['log'] = experiments_root
+         opt['path']['visualization'] = osp.join(experiments_root,
+                                                 'visualization')
+
+         # change some options for debug mode
+         if 'debug' in opt['name']:
+             opt['debug'] = True
+             opt['val_freq'] = 1
+             opt['print_freq'] = 1
+             opt['save_checkpoint_freq'] = 1
+     else:  # test
+         results_root = osp.join(opt['path']['root'], 'results', opt['name'])
+         opt['path']['results_root'] = results_root
+         opt['path']['log'] = results_root
+         opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+     return opt
+
+
+ def dict2str(opt, indent_level=1):
+     """dict to string for printing options.
+
+     Args:
+         opt (dict): Option dict.
+         indent_level (int): Indent level. Default: 1.
+
+     Return:
+         (str): Option string for printing.
+     """
+     msg = ''
+     for k, v in opt.items():
+         if isinstance(v, dict):
+             msg += ' ' * (indent_level * 2) + k + ':[\n'
+             msg += dict2str(v, indent_level + 1)
+             msg += ' ' * (indent_level * 2) + ']\n'
+         else:
+             msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+     return msg
+
+
+ class NoneDict(dict):
+     """None dict. It will return None if the key is not in the dict."""
+
+     def __missing__(self, key):
+         return None
+
+
+ def dict_to_nonedict(opt):
+     """Convert to NoneDict, which returns None for missing keys.
+
+     Args:
+         opt (dict): Option dict.
+
+     Returns:
+         (dict): NoneDict for options.
+     """
+     if isinstance(opt, dict):
+         new_opt = dict()
+         for key, sub_opt in opt.items():
+             new_opt[key] = dict_to_nonedict(sub_opt)
+         return NoneDict(**new_opt)
+     elif isinstance(opt, list):
+         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+     else:
+         return opt
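For reference, a minimal sketch of the intended parse flow (the YAML path is hypothetical; a real config must provide at least name and gpu_ids):

    from utils.options import dict2str, dict_to_nonedict, parse

    opt = parse('configs/demo.yml', is_train=True)  # hypothetical config path
    opt = dict_to_nonedict(opt)  # missing keys now return None, not KeyError
    print(dict2str(opt))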
utils/util.py ADDED
@@ -0,0 +1,123 @@
+ import logging
+ import os
+ import random
+ import sys
+ import time
+ from shutil import get_terminal_size
+
+ import numpy as np
+ import torch
+
+ logger = logging.getLogger('base')
+
+
+ def make_exp_dirs(opt):
+     """Make dirs for experiments."""
+     path_opt = opt['path'].copy()
+     if opt['is_train']:
+         # debug runs may reuse an existing experiment directory
+         overwrite = 'debug' in opt['name']
+         os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
+         os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
+     else:
+         os.makedirs(path_opt.pop('results_root'))
+
+
+ def set_random_seed(seed):
+     """Set random seeds."""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ class ProgressBar(object):
+     """A progress bar which can print the progress.
+
+     Modified from:
+     https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
+     """
+
+     def __init__(self, task_num=0, bar_width=50, start=True):
+         self.task_num = task_num
+         max_bar_width = self._get_max_bar_width()
+         self.bar_width = (
+             bar_width if bar_width <= max_bar_width else max_bar_width)
+         self.completed = 0
+         if start:
+             self.start()
+
+     def _get_max_bar_width(self):
+         terminal_width, _ = get_terminal_size()
+         max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
+         if max_bar_width < 10:
+             print(f'terminal width is too small ({terminal_width}), '
+                   'please consider widening the terminal for better '
+                   'progressbar visualization')
+             max_bar_width = 10
+         return max_bar_width
+
+     def start(self):
+         if self.task_num > 0:
+             sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
+                              f'elapsed: 0s, ETA:\nStart...\n')
+         else:
+             sys.stdout.write('completed: 0, elapsed: 0s')
+         sys.stdout.flush()
+         self.start_time = time.time()
+
+     def update(self, msg='In progress...'):
+         self.completed += 1
+         elapsed = time.time() - self.start_time
+         fps = self.completed / elapsed
+         if self.task_num > 0:
+             percentage = self.completed / float(self.task_num)
+             eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+             mark_width = int(self.bar_width * percentage)
+             bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
+             sys.stdout.write('\033[2F')  # cursor up 2 lines
+             sys.stdout.write(
+                 '\033[J'
+             )  # clean the output (remove extra chars since last display)
+             sys.stdout.write(
+                 f'[{bar_chars}] {self.completed}/{self.task_num}, '
+                 f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
+                 f'ETA: {eta:5}s\n{msg}\n')
+         else:
+             sys.stdout.write(
+                 f'completed: {self.completed}, '
+                 f'elapsed: {int(elapsed + 0.5)}s, {fps:.1f} tasks/s')
+         sys.stdout.flush()
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value.
+
+     Imported from
+     https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
+     """
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0    # most recent value
+         self.avg = 0    # running average = running sum / running count
+         self.sum = 0    # running sum
+         self.count = 0  # running count
+
+     def update(self, val, n=1):
+         # val: batch statistic (e.g. batch accuracy), n: batch size
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
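For reference, a minimal sketch of the two helpers above in a training-style loop (the values are illustrative):

    import time
    from utils.util import AverageMeter, ProgressBar, set_random_seed

    set_random_seed(0)
    acc_meter = AverageMeter()
    pbar = ProgressBar(task_num=5)
    for step in range(5):
        time.sleep(0.1)                  # stand-in for real work
        acc_meter.update(val=0.9, n=16)  # batch accuracy, batch size
        pbar.update(f'step {step}')
    print(f'avg: {acc_meter.avg:.3f} over {acc_meter.count} samples')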