Upload 9 files
Browse files- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/language_utils.cpython-38.pyc +0 -0
- utils/__pycache__/options.cpython-38.pyc +0 -0
- utils/__pycache__/util.cpython-38.pyc +0 -0
- utils/language_utils.py +315 -0
- utils/logger.py +112 -0
- utils/options.py +129 -0
- utils/util.py +123 -0
utils/__init__.py
ADDED
File without changes
|
utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (119 Bytes). View file
|
|
utils/__pycache__/language_utils.cpython-38.pyc
ADDED
Binary file (5.73 kB). View file
|
|
utils/__pycache__/options.cpython-38.pyc
ADDED
Binary file (3.94 kB). View file
|
|
utils/__pycache__/util.cpython-38.pyc
ADDED
Binary file (3.81 kB). View file
|
|
utils/language_utils.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from curses import A_ATTRIBUTES
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
from pip import main
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
|
8 |
+
# predefined shape text
|
9 |
+
upper_length_text = [
|
10 |
+
'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
|
11 |
+
'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
|
12 |
+
'with short sleeves', 'medium-sleeve', 'medium sleeves',
|
13 |
+
'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
|
14 |
+
'long sleeves', 'with long sleeves'
|
15 |
+
]
|
16 |
+
upper_length_attr = {
|
17 |
+
'sleeveless': 0,
|
18 |
+
'without sleeves': 0,
|
19 |
+
'sleeves have been cut off': 0,
|
20 |
+
'tank top': 0,
|
21 |
+
'tank shirt': 0,
|
22 |
+
'muscle shirt': 0,
|
23 |
+
'short-sleeve': 1,
|
24 |
+
'with short sleeves': 1,
|
25 |
+
'short sleeves': 1,
|
26 |
+
'medium-sleeve': 2,
|
27 |
+
'with medium sleeves': 2,
|
28 |
+
'medium sleeves': 2,
|
29 |
+
'sleeves reach elbow': 2,
|
30 |
+
'long-sleeve': 3,
|
31 |
+
'long sleeves': 3,
|
32 |
+
'with long sleeves': 3
|
33 |
+
}
|
34 |
+
lower_length_text = [
|
35 |
+
'three-point', 'medium', 'short', 'covering knee', 'cropped',
|
36 |
+
'three-quarter', 'long', 'slack', 'of long length'
|
37 |
+
]
|
38 |
+
lower_length_attr = {
|
39 |
+
'three-point': 0,
|
40 |
+
'medium': 1,
|
41 |
+
'covering knee': 1,
|
42 |
+
'short': 1,
|
43 |
+
'cropped': 2,
|
44 |
+
'three-quarter': 2,
|
45 |
+
'long': 3,
|
46 |
+
'slack': 3,
|
47 |
+
'of long length': 3
|
48 |
+
}
|
49 |
+
socks_length_text = [
|
50 |
+
'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
|
51 |
+
]
|
52 |
+
socks_length_attr = {
|
53 |
+
'socks': 0,
|
54 |
+
'stocking': 1,
|
55 |
+
'pantyhose': 1,
|
56 |
+
'leggings': 1,
|
57 |
+
'sheer hosiery': 1
|
58 |
+
}
|
59 |
+
hat_text = ['hat', 'cap', 'chapeau']
|
60 |
+
eyeglasses_text = ['sunglasses']
|
61 |
+
belt_text = ['belt', 'with a dress tied around the waist']
|
62 |
+
outer_shape_text = [
|
63 |
+
'with outer clothing open', 'with outer clothing unzipped',
|
64 |
+
'covering inner clothes', 'with outer clothing zipped'
|
65 |
+
]
|
66 |
+
outer_shape_attr = {
|
67 |
+
'with outer clothing open': 0,
|
68 |
+
'with outer clothing unzipped': 0,
|
69 |
+
'covering inner clothes': 1,
|
70 |
+
'with outer clothing zipped': 1
|
71 |
+
}
|
72 |
+
|
73 |
+
upper_types = [
|
74 |
+
'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
|
75 |
+
]
|
76 |
+
outer_types = [
|
77 |
+
'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
|
78 |
+
'duffle', 'cardigan'
|
79 |
+
]
|
80 |
+
skirt_types = ['skirt']
|
81 |
+
dress_types = ['dress']
|
82 |
+
pant_types = ['jeans', 'pants', 'trousers']
|
83 |
+
rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
|
84 |
+
|
85 |
+
attr_names_list = [
|
86 |
+
'gender', 'hair length', '0 upper clothing length',
|
87 |
+
'1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
|
88 |
+
'6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
|
89 |
+
'9 skirt', '10 dress', '11 pants', '12 rompers'
|
90 |
+
]
|
91 |
+
|
92 |
+
|
93 |
+
def generate_shape_attributes(user_shape_texts):
|
94 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
95 |
+
parsed_texts = user_shape_texts.split(',')
|
96 |
+
|
97 |
+
text_num = len(parsed_texts)
|
98 |
+
|
99 |
+
human_attr = [0, 0]
|
100 |
+
attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
|
101 |
+
|
102 |
+
changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
103 |
+
for text_id, text in enumerate(parsed_texts):
|
104 |
+
user_embeddings = model.encode(text)
|
105 |
+
if ('man' in text) and (text_id == 0):
|
106 |
+
human_attr[0] = 0
|
107 |
+
human_attr[1] = 0
|
108 |
+
|
109 |
+
if ('woman' in text or 'lady' in text) and (text_id == 0):
|
110 |
+
human_attr[0] = 1
|
111 |
+
human_attr[1] = 2
|
112 |
+
|
113 |
+
if (not changed[0]) and (text_id == 1):
|
114 |
+
# upper length
|
115 |
+
predefined_embeddings = model.encode(upper_length_text)
|
116 |
+
similarities = util.dot_score(user_embeddings,
|
117 |
+
predefined_embeddings)
|
118 |
+
arg_idx = torch.argmax(similarities).item()
|
119 |
+
attr[0] = upper_length_attr[upper_length_text[arg_idx]]
|
120 |
+
changed[0] = 1
|
121 |
+
|
122 |
+
if (not changed[1]) and ((text_num == 2 and text_id == 1) or
|
123 |
+
(text_num > 2 and text_id == 2)):
|
124 |
+
# lower length
|
125 |
+
predefined_embeddings = model.encode(lower_length_text)
|
126 |
+
similarities = util.dot_score(user_embeddings,
|
127 |
+
predefined_embeddings)
|
128 |
+
arg_idx = torch.argmax(similarities).item()
|
129 |
+
attr[1] = lower_length_attr[lower_length_text[arg_idx]]
|
130 |
+
changed[1] = 1
|
131 |
+
|
132 |
+
if (not changed[2]) and (text_id > 2):
|
133 |
+
# socks length
|
134 |
+
predefined_embeddings = model.encode(socks_length_text)
|
135 |
+
similarities = util.dot_score(user_embeddings,
|
136 |
+
predefined_embeddings)
|
137 |
+
arg_idx = torch.argmax(similarities).item()
|
138 |
+
if similarities[0][arg_idx] > 0.7:
|
139 |
+
attr[2] = arg_idx + 1
|
140 |
+
changed[2] = 1
|
141 |
+
|
142 |
+
if (not changed[3]) and (text_id > 2):
|
143 |
+
# hat
|
144 |
+
predefined_embeddings = model.encode(hat_text)
|
145 |
+
similarities = util.dot_score(user_embeddings,
|
146 |
+
predefined_embeddings)
|
147 |
+
if similarities[0][0] > 0.7:
|
148 |
+
attr[3] = 1
|
149 |
+
changed[3] = 1
|
150 |
+
|
151 |
+
if (not changed[4]) and (text_id > 2):
|
152 |
+
# glasses
|
153 |
+
predefined_embeddings = model.encode(eyeglasses_text)
|
154 |
+
similarities = util.dot_score(user_embeddings,
|
155 |
+
predefined_embeddings)
|
156 |
+
arg_idx = torch.argmax(similarities).item()
|
157 |
+
if similarities[0][arg_idx] > 0.7:
|
158 |
+
attr[4] = arg_idx + 1
|
159 |
+
changed[4] = 1
|
160 |
+
|
161 |
+
if (not changed[5]) and (text_id > 2):
|
162 |
+
# belt
|
163 |
+
predefined_embeddings = model.encode(belt_text)
|
164 |
+
similarities = util.dot_score(user_embeddings,
|
165 |
+
predefined_embeddings)
|
166 |
+
arg_idx = torch.argmax(similarities).item()
|
167 |
+
if similarities[0][arg_idx] > 0.7:
|
168 |
+
attr[5] = arg_idx + 1
|
169 |
+
changed[5] = 1
|
170 |
+
|
171 |
+
if (not changed[6]) and (text_id == 3):
|
172 |
+
# outer coverage
|
173 |
+
predefined_embeddings = model.encode(outer_shape_text)
|
174 |
+
similarities = util.dot_score(user_embeddings,
|
175 |
+
predefined_embeddings)
|
176 |
+
arg_idx = torch.argmax(similarities).item()
|
177 |
+
if similarities[0][arg_idx] > 0.7:
|
178 |
+
attr[6] = arg_idx
|
179 |
+
changed[6] = 1
|
180 |
+
|
181 |
+
if (not changed[10]) and (text_num == 2 and text_id == 1):
|
182 |
+
# dress_types
|
183 |
+
predefined_embeddings = model.encode(dress_types)
|
184 |
+
similarities = util.dot_score(user_embeddings,
|
185 |
+
predefined_embeddings)
|
186 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
187 |
+
model.encode(skirt_types))
|
188 |
+
if similarities[0][0] > 0.5 and similarities[0][
|
189 |
+
0] > similarity_skirt[0][0]:
|
190 |
+
attr[10] = 1
|
191 |
+
attr[7] = 0
|
192 |
+
attr[8] = 0
|
193 |
+
attr[9] = 0
|
194 |
+
attr[11] = 0
|
195 |
+
attr[12] = 0
|
196 |
+
|
197 |
+
changed[0] = 1
|
198 |
+
changed[10] = 1
|
199 |
+
changed[7] = 1
|
200 |
+
changed[8] = 1
|
201 |
+
changed[9] = 1
|
202 |
+
changed[11] = 1
|
203 |
+
changed[12] = 1
|
204 |
+
|
205 |
+
if (not changed[12]) and (text_num == 2 and text_id == 1):
|
206 |
+
# rompers_types
|
207 |
+
predefined_embeddings = model.encode(rompers_types)
|
208 |
+
similarities = util.dot_score(user_embeddings,
|
209 |
+
predefined_embeddings)
|
210 |
+
max_similarity = torch.max(similarities).item()
|
211 |
+
if max_similarity > 0.6:
|
212 |
+
attr[12] = 1
|
213 |
+
attr[7] = 0
|
214 |
+
attr[8] = 0
|
215 |
+
attr[9] = 0
|
216 |
+
attr[10] = 0
|
217 |
+
attr[11] = 0
|
218 |
+
|
219 |
+
changed[12] = 1
|
220 |
+
changed[7] = 1
|
221 |
+
changed[8] = 1
|
222 |
+
changed[9] = 1
|
223 |
+
changed[10] = 1
|
224 |
+
changed[11] = 1
|
225 |
+
|
226 |
+
if (not changed[7]) and (text_num > 2 and text_id == 1):
|
227 |
+
# upper_types
|
228 |
+
predefined_embeddings = model.encode(upper_types)
|
229 |
+
similarities = util.dot_score(user_embeddings,
|
230 |
+
predefined_embeddings)
|
231 |
+
max_similarity = torch.max(similarities).item()
|
232 |
+
if max_similarity > 0.6:
|
233 |
+
attr[7] = 1
|
234 |
+
changed[7] = 1
|
235 |
+
|
236 |
+
if (not changed[8]) and (text_id == 3):
|
237 |
+
# outer_types
|
238 |
+
predefined_embeddings = model.encode(outer_types)
|
239 |
+
similarities = util.dot_score(user_embeddings,
|
240 |
+
predefined_embeddings)
|
241 |
+
arg_idx = torch.argmax(similarities).item()
|
242 |
+
if similarities[0][arg_idx] > 0.7:
|
243 |
+
attr[6] = outer_shape_attr[outer_shape_text[arg_idx]]
|
244 |
+
attr[8] = 1
|
245 |
+
changed[8] = 1
|
246 |
+
|
247 |
+
if (not changed[9]) and (text_num > 2 and text_id == 2):
|
248 |
+
# skirt_types
|
249 |
+
predefined_embeddings = model.encode(skirt_types)
|
250 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
251 |
+
predefined_embeddings)
|
252 |
+
similarity_dress = util.dot_score(user_embeddings,
|
253 |
+
model.encode(dress_types))
|
254 |
+
if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
|
255 |
+
0] > similarity_dress[0][0]:
|
256 |
+
attr[9] = 1
|
257 |
+
attr[10] = 0
|
258 |
+
changed[9] = 1
|
259 |
+
changed[10] = 1
|
260 |
+
|
261 |
+
if (not changed[11]) and (text_num > 2 and text_id == 2):
|
262 |
+
# pant_types
|
263 |
+
predefined_embeddings = model.encode(pant_types)
|
264 |
+
similarities = util.dot_score(user_embeddings,
|
265 |
+
predefined_embeddings)
|
266 |
+
max_similarity = torch.max(similarities).item()
|
267 |
+
if max_similarity > 0.6:
|
268 |
+
attr[11] = 1
|
269 |
+
attr[9] = 0
|
270 |
+
attr[10] = 0
|
271 |
+
attr[12] = 0
|
272 |
+
changed[11] = 1
|
273 |
+
changed[9] = 1
|
274 |
+
changed[10] = 1
|
275 |
+
changed[12] = 1
|
276 |
+
|
277 |
+
return human_attr + attr
|
278 |
+
|
279 |
+
|
280 |
+
def generate_texture_attributes(user_text):
|
281 |
+
parsed_texts = user_text.split(',')
|
282 |
+
|
283 |
+
attr = []
|
284 |
+
for text in parsed_texts:
|
285 |
+
if ('pure color' in text) or ('solid color' in text):
|
286 |
+
attr.append(4)
|
287 |
+
elif ('spline' in text) or ('stripe' in text):
|
288 |
+
attr.append(3)
|
289 |
+
elif ('plaid' in text) or ('lattice' in text):
|
290 |
+
attr.append(5)
|
291 |
+
elif 'floral' in text:
|
292 |
+
attr.append(1)
|
293 |
+
elif 'denim' in text:
|
294 |
+
attr.append(0)
|
295 |
+
else:
|
296 |
+
attr.append(17)
|
297 |
+
|
298 |
+
if len(attr) == 1:
|
299 |
+
attr.append(attr[0])
|
300 |
+
attr.append(17)
|
301 |
+
|
302 |
+
if len(attr) == 2:
|
303 |
+
attr.append(17)
|
304 |
+
|
305 |
+
return attr
|
306 |
+
|
307 |
+
|
308 |
+
if __name__ == "__main__":
|
309 |
+
user_request = input('Enter your request: ')
|
310 |
+
while user_request != '\\q':
|
311 |
+
attr = generate_shape_attributes(user_request)
|
312 |
+
print(attr)
|
313 |
+
for attr_name, attr_value in zip(attr_names_list, attr):
|
314 |
+
print(attr_name, attr_value)
|
315 |
+
user_request = input('Enter your request: ')
|
utils/logger.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
|
5 |
+
|
6 |
+
class MessageLogger():
|
7 |
+
"""Message logger for printing.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
opt (dict): Config. It contains the following keys:
|
11 |
+
name (str): Exp name.
|
12 |
+
logger (dict): Contains 'print_freq' (str) for logger interval.
|
13 |
+
train (dict): Contains 'niter' (int) for total iters.
|
14 |
+
use_tb_logger (bool): Use tensorboard logger.
|
15 |
+
start_iter (int): Start iter. Default: 1.
|
16 |
+
tb_logger (obj:`tb_logger`): Tensorboard logger. DefaultοΌ None.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, opt, start_iter=1, tb_logger=None):
|
20 |
+
self.exp_name = opt['name']
|
21 |
+
self.interval = opt['print_freq']
|
22 |
+
self.start_iter = start_iter
|
23 |
+
self.max_iters = opt['max_iters']
|
24 |
+
self.use_tb_logger = opt['use_tb_logger']
|
25 |
+
self.tb_logger = tb_logger
|
26 |
+
self.start_time = time.time()
|
27 |
+
self.logger = get_root_logger()
|
28 |
+
|
29 |
+
def __call__(self, log_vars):
|
30 |
+
"""Format logging message.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
log_vars (dict): It contains the following keys:
|
34 |
+
epoch (int): Epoch number.
|
35 |
+
iter (int): Current iter.
|
36 |
+
lrs (list): List for learning rates.
|
37 |
+
|
38 |
+
time (float): Iter time.
|
39 |
+
data_time (float): Data time for each iter.
|
40 |
+
"""
|
41 |
+
# epoch, iter, learning rates
|
42 |
+
epoch = log_vars.pop('epoch')
|
43 |
+
current_iter = log_vars.pop('iter')
|
44 |
+
lrs = log_vars.pop('lrs')
|
45 |
+
|
46 |
+
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
|
47 |
+
f'iter:{current_iter:8,d}, lr:(')
|
48 |
+
for v in lrs:
|
49 |
+
message += f'{v:.3e},'
|
50 |
+
message += ')] '
|
51 |
+
|
52 |
+
# time and estimated time
|
53 |
+
if 'time' in log_vars.keys():
|
54 |
+
iter_time = log_vars.pop('time')
|
55 |
+
data_time = log_vars.pop('data_time')
|
56 |
+
|
57 |
+
total_time = time.time() - self.start_time
|
58 |
+
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
|
59 |
+
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
|
60 |
+
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
61 |
+
message += f'[eta: {eta_str}, '
|
62 |
+
message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
|
63 |
+
|
64 |
+
# other items, especially losses
|
65 |
+
for k, v in log_vars.items():
|
66 |
+
message += f'{k}: {v:.4e} '
|
67 |
+
# tensorboard logger
|
68 |
+
if self.use_tb_logger and 'debug' not in self.exp_name:
|
69 |
+
self.tb_logger.add_scalar(k, v, current_iter)
|
70 |
+
|
71 |
+
self.logger.info(message)
|
72 |
+
|
73 |
+
|
74 |
+
def init_tb_logger(log_dir):
|
75 |
+
from torch.utils.tensorboard import SummaryWriter
|
76 |
+
tb_logger = SummaryWriter(log_dir=log_dir)
|
77 |
+
return tb_logger
|
78 |
+
|
79 |
+
|
80 |
+
def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
|
81 |
+
"""Get the root logger.
|
82 |
+
|
83 |
+
The logger will be initialized if it has not been initialized. By default a
|
84 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
85 |
+
also be added.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
logger_name (str): root logger name. Default: base.
|
89 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
90 |
+
will be added to the root logger.
|
91 |
+
log_level (int): The root logger level. Note that only the process of
|
92 |
+
rank 0 is affected, while other processes will set the level to
|
93 |
+
"Error" and be silent most of the time.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
logging.Logger: The root logger.
|
97 |
+
"""
|
98 |
+
logger = logging.getLogger(logger_name)
|
99 |
+
# if the logger has been initialized, just return it
|
100 |
+
if logger.hasHandlers():
|
101 |
+
return logger
|
102 |
+
|
103 |
+
format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
|
104 |
+
logging.basicConfig(format=format_str, level=log_level)
|
105 |
+
|
106 |
+
if log_file is not None:
|
107 |
+
file_handler = logging.FileHandler(log_file, 'w')
|
108 |
+
file_handler.setFormatter(logging.Formatter(format_str))
|
109 |
+
file_handler.setLevel(log_level)
|
110 |
+
logger.addHandler(file_handler)
|
111 |
+
|
112 |
+
return logger
|
utils/options.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
|
8 |
+
def ordered_yaml():
|
9 |
+
"""Support OrderedDict for yaml.
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
yaml Loader and Dumper.
|
13 |
+
"""
|
14 |
+
try:
|
15 |
+
from yaml import CDumper as Dumper
|
16 |
+
from yaml import CLoader as Loader
|
17 |
+
except ImportError:
|
18 |
+
from yaml import Dumper, Loader
|
19 |
+
|
20 |
+
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
21 |
+
|
22 |
+
def dict_representer(dumper, data):
|
23 |
+
return dumper.represent_dict(data.items())
|
24 |
+
|
25 |
+
def dict_constructor(loader, node):
|
26 |
+
return OrderedDict(loader.construct_pairs(node))
|
27 |
+
|
28 |
+
Dumper.add_representer(OrderedDict, dict_representer)
|
29 |
+
Loader.add_constructor(_mapping_tag, dict_constructor)
|
30 |
+
return Loader, Dumper
|
31 |
+
|
32 |
+
|
33 |
+
def parse(opt_path, is_train=True):
|
34 |
+
"""Parse option file.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
opt_path (str): Option file path.
|
38 |
+
is_train (str): Indicate whether in training or not. Default: True.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
(dict): Options.
|
42 |
+
"""
|
43 |
+
with open(opt_path, mode='r') as f:
|
44 |
+
Loader, _ = ordered_yaml()
|
45 |
+
opt = yaml.load(f, Loader=Loader)
|
46 |
+
|
47 |
+
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
48 |
+
if opt.get('set_CUDA_VISIBLE_DEVICES', None):
|
49 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
50 |
+
print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
|
51 |
+
else:
|
52 |
+
print('gpu_list: ', gpu_list, flush=True)
|
53 |
+
|
54 |
+
opt['is_train'] = is_train
|
55 |
+
|
56 |
+
# paths
|
57 |
+
opt['path'] = {}
|
58 |
+
opt['path']['root'] = osp.abspath(
|
59 |
+
osp.join(__file__, osp.pardir, osp.pardir))
|
60 |
+
if is_train:
|
61 |
+
experiments_root = osp.join(opt['path']['root'], 'experiments',
|
62 |
+
opt['name'])
|
63 |
+
opt['path']['experiments_root'] = experiments_root
|
64 |
+
opt['path']['models'] = osp.join(experiments_root, 'models')
|
65 |
+
opt['path']['log'] = experiments_root
|
66 |
+
opt['path']['visualization'] = osp.join(experiments_root,
|
67 |
+
'visualization')
|
68 |
+
|
69 |
+
# change some options for debug mode
|
70 |
+
if 'debug' in opt['name']:
|
71 |
+
opt['debug'] = True
|
72 |
+
opt['val_freq'] = 1
|
73 |
+
opt['print_freq'] = 1
|
74 |
+
opt['save_checkpoint_freq'] = 1
|
75 |
+
else: # test
|
76 |
+
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
|
77 |
+
opt['path']['results_root'] = results_root
|
78 |
+
opt['path']['log'] = results_root
|
79 |
+
opt['path']['visualization'] = osp.join(results_root, 'visualization')
|
80 |
+
|
81 |
+
return opt
|
82 |
+
|
83 |
+
|
84 |
+
def dict2str(opt, indent_level=1):
|
85 |
+
"""dict to string for printing options.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
opt (dict): Option dict.
|
89 |
+
indent_level (int): Indent level. Default: 1.
|
90 |
+
|
91 |
+
Return:
|
92 |
+
(str): Option string for printing.
|
93 |
+
"""
|
94 |
+
msg = ''
|
95 |
+
for k, v in opt.items():
|
96 |
+
if isinstance(v, dict):
|
97 |
+
msg += ' ' * (indent_level * 2) + k + ':[\n'
|
98 |
+
msg += dict2str(v, indent_level + 1)
|
99 |
+
msg += ' ' * (indent_level * 2) + ']\n'
|
100 |
+
else:
|
101 |
+
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
|
102 |
+
return msg
|
103 |
+
|
104 |
+
|
105 |
+
class NoneDict(dict):
|
106 |
+
"""None dict. It will return none if key is not in the dict."""
|
107 |
+
|
108 |
+
def __missing__(self, key):
|
109 |
+
return None
|
110 |
+
|
111 |
+
|
112 |
+
def dict_to_nonedict(opt):
|
113 |
+
"""Convert to NoneDict, which returns None for missing keys.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
opt (dict): Option dict.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
(dict): NoneDict for options.
|
120 |
+
"""
|
121 |
+
if isinstance(opt, dict):
|
122 |
+
new_opt = dict()
|
123 |
+
for key, sub_opt in opt.items():
|
124 |
+
new_opt[key] = dict_to_nonedict(sub_opt)
|
125 |
+
return NoneDict(**new_opt)
|
126 |
+
elif isinstance(opt, list):
|
127 |
+
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
128 |
+
else:
|
129 |
+
return opt
|
utils/util.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
from shutil import get_terminal_size
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
logger = logging.getLogger('base')
|
12 |
+
|
13 |
+
|
14 |
+
def make_exp_dirs(opt):
|
15 |
+
"""Make dirs for experiments."""
|
16 |
+
path_opt = opt['path'].copy()
|
17 |
+
if opt['is_train']:
|
18 |
+
overwrite = True if 'debug' in opt['name'] else False
|
19 |
+
os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
|
20 |
+
os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
|
21 |
+
else:
|
22 |
+
os.makedirs(path_opt.pop('results_root'))
|
23 |
+
|
24 |
+
|
25 |
+
def set_random_seed(seed):
|
26 |
+
"""Set random seeds."""
|
27 |
+
random.seed(seed)
|
28 |
+
np.random.seed(seed)
|
29 |
+
torch.manual_seed(seed)
|
30 |
+
torch.cuda.manual_seed(seed)
|
31 |
+
torch.cuda.manual_seed_all(seed)
|
32 |
+
|
33 |
+
|
34 |
+
class ProgressBar(object):
|
35 |
+
"""A progress bar which can print the progress.
|
36 |
+
|
37 |
+
Modified from:
|
38 |
+
https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, task_num=0, bar_width=50, start=True):
|
42 |
+
self.task_num = task_num
|
43 |
+
max_bar_width = self._get_max_bar_width()
|
44 |
+
self.bar_width = (
|
45 |
+
bar_width if bar_width <= max_bar_width else max_bar_width)
|
46 |
+
self.completed = 0
|
47 |
+
if start:
|
48 |
+
self.start()
|
49 |
+
|
50 |
+
def _get_max_bar_width(self):
|
51 |
+
terminal_width, _ = get_terminal_size()
|
52 |
+
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
|
53 |
+
if max_bar_width < 10:
|
54 |
+
print(f'terminal width is too small ({terminal_width}), '
|
55 |
+
'please consider widen the terminal for better '
|
56 |
+
'progressbar visualization')
|
57 |
+
max_bar_width = 10
|
58 |
+
return max_bar_width
|
59 |
+
|
60 |
+
def start(self):
|
61 |
+
if self.task_num > 0:
|
62 |
+
sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
|
63 |
+
f'elapsed: 0s, ETA:\nStart...\n')
|
64 |
+
else:
|
65 |
+
sys.stdout.write('completed: 0, elapsed: 0s')
|
66 |
+
sys.stdout.flush()
|
67 |
+
self.start_time = time.time()
|
68 |
+
|
69 |
+
def update(self, msg='In progress...'):
|
70 |
+
self.completed += 1
|
71 |
+
elapsed = time.time() - self.start_time
|
72 |
+
fps = self.completed / elapsed
|
73 |
+
if self.task_num > 0:
|
74 |
+
percentage = self.completed / float(self.task_num)
|
75 |
+
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
76 |
+
mark_width = int(self.bar_width * percentage)
|
77 |
+
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
|
78 |
+
sys.stdout.write('\033[2F') # cursor up 2 lines
|
79 |
+
sys.stdout.write(
|
80 |
+
'\033[J'
|
81 |
+
) # clean the output (remove extra chars since last display)
|
82 |
+
sys.stdout.write(
|
83 |
+
f'[{bar_chars}] {self.completed}/{self.task_num}, '
|
84 |
+
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
|
85 |
+
f'ETA: {eta:5}s\n{msg}\n')
|
86 |
+
else:
|
87 |
+
sys.stdout.write(
|
88 |
+
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
|
89 |
+
f'{fps:.1f} tasks/s')
|
90 |
+
sys.stdout.flush()
|
91 |
+
|
92 |
+
|
93 |
+
class AverageMeter(object):
|
94 |
+
"""
|
95 |
+
Computes and stores the average and current value
|
96 |
+
Imported from
|
97 |
+
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self):
|
101 |
+
self.reset()
|
102 |
+
|
103 |
+
def reset(self):
|
104 |
+
self.val = 0
|
105 |
+
self.avg = 0 # running average = running sum / running count
|
106 |
+
self.sum = 0 # running sum
|
107 |
+
self.count = 0 # running count
|
108 |
+
|
109 |
+
def update(self, val, n=1):
|
110 |
+
# n = batch_size
|
111 |
+
|
112 |
+
# val = batch accuracy for an attribute
|
113 |
+
# self.val = val
|
114 |
+
|
115 |
+
# sum = 100 * accumulative correct predictions for this attribute
|
116 |
+
self.sum += val * n
|
117 |
+
|
118 |
+
# count = total samples so far
|
119 |
+
self.count += n
|
120 |
+
|
121 |
+
# avg = 100 * avg accuracy for this attribute
|
122 |
+
# for all the batches so far
|
123 |
+
self.avg = self.sum / self.count
|