Cyril666 committed on
Commit
e5414eb
·
1 Parent(s): 616dad3

First model version

Browse files
Files changed (5) hide show
  1. callbacks.py +0 -360
  2. dataset.py +0 -278
  3. losses.py +0 -72
  4. main.py +0 -246
  5. transforms.py +0 -329
callbacks.py DELETED
@@ -1,360 +0,0 @@
1
- import logging
2
- import shutil
3
- import time
4
-
5
- import editdistance as ed
6
- import torchvision.utils as vutils
7
- from fastai.callbacks.tensorboard import (LearnerTensorboardWriter,
8
- SummaryWriter, TBWriteRequest,
9
- asyncTBWriter)
10
- from fastai.vision import *
11
- from torch.nn.parallel import DistributedDataParallel
12
- from torchvision import transforms
13
-
14
- import dataset
15
- from utils import CharsetMapper, Timer, blend_mask
16
-
17
-
18
class IterationCallback(LearnerTensorboardWriter):
    "Tensorboard writer that logs, validates and checkpoints on an iteration schedule."
    def __init__(self, learn:Learner, name:str='model', checpoint_keep_num=5,
                 show_iters:int=50, eval_iters:int=1000, save_iters:int=20000,
                 start_iters:int=0, stats_iters=20000):
        # NOTE: `checpoint_keep_num` is misspelled but is part of the public
        # signature, so the name is kept for backward compatibility.
        #if self.learn.rank is not None: time.sleep(self.learn.rank) # keep all event files
        super().__init__(learn, base_dir='.', name=learn.path, loss_iters=show_iters,
                         stats_iters=stats_iters, hist_iters=stats_iters)
        self.name, self.bestname = Path(name).name, f'best-{Path(name).name}'
        self.show_iters = show_iters
        self.eval_iters = eval_iters
        self.save_iters = save_iters
        self.start_iters = start_iters
        self.checpoint_keep_num = checpoint_keep_num
        self.metrics_root = 'metrics/'  # rewrite
        self.timer = Timer()
        # Only the host process (no rank, or rank 0) touches checkpoint files.
        self.host = self.learn.rank is None or self.learn.rank == 0

    def _write_metrics(self, iteration:int, names:List[str], last_metrics:MetricsList)->None:
        "Writes training metrics to Tensorboard."
        for i, name in enumerate(names):
            # Stop as soon as there is no metric value for the current name.
            if last_metrics is None or len(last_metrics) < i+1: return
            scalar_value = last_metrics[i]
            self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)

    def _write_sub_loss(self, iteration:int, last_losses:dict)->None:
        "Writes sub loss to Tensorboard."
        for name, loss in last_losses.items():
            scalar_value = to_np(loss)
            tag = self.metrics_root + name
            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _save(self, name):
        "Save the model as `name`, unwrapping `DistributedDataParallel` for the duration of the save."
        if isinstance(self.learn.model, DistributedDataParallel):
            tmp = self.learn.model
            self.learn.model = self.learn.model.module
            self.learn.save(name)
            self.learn.model = tmp
        else: self.learn.save(name)

    def _validate(self, dl=None, callbacks=None, metrics=None, keeped_items=False):
        "Validate on `dl` with potential `callbacks` and `metrics`."
        dl = ifnone(dl, self.learn.data.valid_dl)
        metrics = ifnone(metrics, self.learn.metrics)
        cb_handler = CallbackHandler(ifnone(callbacks, []), metrics)
        cb_handler.on_train_begin(1, None, metrics); cb_handler.on_epoch_begin()
        if keeped_items: cb_handler.state_dict.update(dict(keeped_items=[]))
        val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
        cb_handler.on_epoch_end(val_metrics)
        if keeped_items: return cb_handler.state_dict['keeped_items']
        else: return cb_handler.state_dict['last_metrics']

    def _rotate_checkpoints(self, filename):
        "Record `filename` in `checkpoint.yaml` and delete the oldest checkpoint beyond the keep limit."
        checkpoint_path = self.learn.path/'checkpoint.yaml'
        if not checkpoint_path.exists():
            with open(checkpoint_path, 'w'): pass
        with open(checkpoint_path, 'r') as file:
            checkpoints = yaml.load(file, Loader=yaml.FullLoader) or dict()
        checkpoints['all_checkpoints'] = (
            checkpoints.get('all_checkpoints') or list())
        checkpoints['all_checkpoints'].insert(0, filename)  # newest first
        if len(checkpoints['all_checkpoints']) > self.checpoint_keep_num:
            removed_checkpoint = checkpoints['all_checkpoints'].pop()
            removed_checkpoint = self.learn.path/self.learn.model_dir/f'{removed_checkpoint}.pth'
            # A checkpoint that was already deleted by hand must not abort training.
            try: os.remove(removed_checkpoint)
            except FileNotFoundError:
                logging.warning(f'Checkpoint {removed_checkpoint} is already removed.')
        checkpoints['current_checkpoint'] = filename
        with open(checkpoint_path, 'w') as file:
            yaml.dump(checkpoints, file)

    def jump_to_epoch_iter(self, epoch:int, iteration:int)->None:
        "Try to load the checkpoint named `{name}_{epoch}_{iteration}`; log instead of raising on failure."
        try:
            self.learn.load(f'{self.name}_{epoch}_{iteration}', purge=False)
            logging.info(f'Loaded {self.name}_{epoch}_{iteration}')
        # `Exception` rather than a bare `except` so Ctrl-C / SystemExit still propagate.
        except Exception: logging.info(f'Model {self.name}_{epoch}_{iteration} not found.')

    def on_train_begin(self, n_epochs, **kwargs):
        "Reset the best score and timer, and start from an empty `checkpoint.yaml` on the host."
        # TODO: can not write graph here
        # super().on_train_begin(**kwargs)
        self.best = -float('inf')
        self.timer.tic()
        if self.host:
            checkpoint_path = self.learn.path/'checkpoint.yaml'
            if checkpoint_path.exists():
                os.remove(checkpoint_path)
            with open(checkpoint_path, 'w'): pass
        return {'skip_validate': True, 'iteration':self.start_iters}  # disable default validate

    def on_batch_begin(self, **kwargs:Any)->None:
        self.timer.toc_data()
        super().on_batch_begin(**kwargs)

    def on_batch_end(self, iteration, epoch, last_loss, smooth_loss, train, **kwargs):
        "Log losses, run validation and save checkpoints on their respective iteration periods."
        super().on_batch_end(last_loss, iteration, train, **kwargs)
        if iteration == 0: return

        if iteration % self.loss_iters == 0:
            last_losses = self.learn.loss_func.last_losses
            self._write_sub_loss(iteration=iteration, last_losses=last_losses)
            self.tbwriter.add_scalar(tag=self.metrics_root + 'lr',
                                     scalar_value=self.opt.lr, global_step=iteration)

        if iteration % self.show_iters == 0:
            log_str = f'epoch {epoch} iter {iteration}: loss = {last_loss:6.4f}, ' \
                      f'smooth loss = {smooth_loss:6.4f}'
            logging.info(log_str)
            # log_str = f'data time = {self.timer.data_diff:.4f}s, runing time = {self.timer.running_diff:.4f}s'
            # logging.info(log_str)

        if iteration % self.eval_iters == 0:
            # TODO: or remove time to on_epoch_end
            # 1. Record time
            log_str = f'average data time = {self.timer.average_data_time():.4f}s, ' \
                      f'average running time = {self.timer.average_running_time():.4f}s'
            logging.info(log_str)

            # 2. Call validate
            last_metrics = self._validate()
            self.learn.model.train()  # back to training mode after validation
            log_str = f'epoch {epoch} iter {iteration}: eval loss = {last_metrics[0]:6.4f}, ' \
                      f'ccr = {last_metrics[1]:6.4f}, cwr = {last_metrics[2]:6.4f}, ' \
                      f'ted = {last_metrics[3]:6.4f}, ned = {last_metrics[4]:6.4f}, ' \
                      f'ted/w = {last_metrics[5]:6.4f}, '
            logging.info(log_str)
            names = ['eval_loss', 'ccr', 'cwr', 'ted', 'ned', 'ted/w']
            self._write_metrics(iteration, names, last_metrics)

            # 3. Save best model (index 2 is the value logged as `cwr` above)
            current = last_metrics[2]
            if current is not None and current > self.best:
                logging.info(f'Better model found at epoch {epoch}, '\
                             f'iter {iteration} with accuracy value: {current:6.4f}.')
                self.best = current
                self._save(f'{self.bestname}')

        if iteration % self.save_iters == 0 and self.host:
            logging.info(f'Save model {self.name}_{epoch}_{iteration}')
            filename = f'{self.name}_{epoch}_{iteration}'
            self._save(filename)
            self._rotate_checkpoints(filename)

        self.timer.toc_running()

    def on_train_end(self, **kwargs):
        #self.learn.load(f'{self.bestname}', purge=False)
        pass

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
        self._write_embedding(iteration=iteration)
165
-
166
-
167
class TextAccuracy(Callback):
    "Accumulates character/word accuracy and edit-distance metrics over an epoch."
    _names = ['ccr', 'cwr', 'ted', 'ned', 'ted/w']
    def __init__(self, charset_path, max_length, case_sensitive, model_eval):
        self.charset_path = charset_path
        self.max_length = max_length
        self.case_sensitive = case_sensitive
        self.charset = CharsetMapper(charset_path, self.max_length)
        self.names = self._names

        # Which model head is scored when the model returns several outputs.
        self.model_eval = model_eval or 'alignment'
        assert self.model_eval in ['vision', 'language', 'alignment']

    def on_epoch_begin(self, **kwargs):
        "Reset all running counters."
        self.total_num_char = 0.
        self.total_num_word = 0.
        self.correct_num_char = 0.
        self.correct_num_word = 0.
        self.total_ed = 0.
        self.total_ned = 0.

    def _get_output(self, last_output):
        "Return the result dict named `self.model_eval`, or `last_output` itself if it is not a sequence."
        if not isinstance(last_output, (tuple, list)): return last_output
        output = None
        for res in last_output:
            if res['name'] == self.model_eval: output = res  # last match wins
        if output is None:
            # Previously this surfaced as an opaque UnboundLocalError.
            raise ValueError(f'No model output named {self.model_eval!r} was found.')
        return output

    def _update_output(self, last_output, items):
        "Merge `items` into the result dict named `self.model_eval` (in place) and return `last_output`."
        if isinstance(last_output, (tuple, list)):
            for res in last_output:
                if res['name'] == self.model_eval: res.update(items)
        else: last_output.update(items)
        return last_output

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Decode the batch, attach `pt_text`/`pt_scores` to the output and update the counters."
        output = self._get_output(last_output)
        logits, pt_lengths = output['logits'], output['pt_lengths']
        pt_text, pt_scores, pt_lengths_ = self.decode(logits)
        assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
        last_output = self._update_output(last_output, {'pt_text':pt_text, 'pt_scores':pt_scores})

        pt_text = [self.charset.trim(t) for t in pt_text]
        label = last_target[0]
        if label.dim() == 3: label = label.argmax(dim=-1)  # one-hot label
        gt_text = [self.charset.get_text(l, trim=True) for l in label]

        for gt, pt in zip(gt_text, pt_text):
            if not self.case_sensitive: gt, pt = gt.lower(), pt.lower()
            distance = ed.eval(gt, pt)
            self.total_ed += distance
            # Normalise by the ground-truth length; `max(..., 1)` guards empty labels.
            self.total_ned += float(distance) / max(len(gt), 1)

            if gt == pt: self.correct_num_word += 1
            self.total_num_word += 1

            # Position-wise character matches over the common prefix length.
            self.correct_num_char += sum(g == p for g, p in zip(gt, pt))
            self.total_num_char += len(gt)

        return {'last_output': last_output}

    @staticmethod
    def _ratio(numerator, denominator):
        "Return `numerator / denominator`, or 0. when nothing was counted."
        return numerator / denominator if denominator else 0.

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the five metrics, in the order of `_names`, to `last_metrics`."
        mets = [self._ratio(self.correct_num_char, self.total_num_char),
                self._ratio(self.correct_num_word, self.total_num_word),
                self.total_ed,
                self.total_ned,
                self._ratio(self.total_ed, self.total_num_word)]
        return add_metrics(last_metrics, mets)

    def decode(self, logit):
        """ Greed decode """
        # TODO: test running time and decode on GPU
        out = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in out:
            text = self.charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(self.charset.null_char)[0]  # end at end-token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            pt_lengths.append(min(len(text) + 1, self.max_length))  # one for end-token
        pt_scores = torch.stack(pt_scores)
        pt_lengths = pt_scores.new_tensor(pt_lengths, dtype=torch.long)
        return pt_text, pt_scores, pt_lengths
253
-
254
-
255
class TopKTextAccuracy(TextAccuracy):
    "Top-k variant: a character is correct when its label is among the k highest logits."
    _names = ['ccr', 'cwr']
    def __init__(self, k, charset_path, max_length, case_sensitive, model_eval):
        self.k = k
        self.charset_path = charset_path
        self.max_length = max_length
        self.case_sensitive = case_sensitive
        self.charset = CharsetMapper(charset_path, self.max_length)
        self.names = self._names

    def on_epoch_begin(self, **kwargs):
        "Zero the character and word counters."
        self.total_num_char = 0.
        self.total_num_word = 0.
        self.correct_num_char = 0.
        self.correct_num_word = 0.

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Count top-k character hits and fully-correct words for one batch."
        logits = last_output['logits']
        pt_lengths = last_output['pt_lengths']
        gt_labels, gt_lengths = last_target[:]

        for logit, pt_length, label, length in zip(logits, pt_lengths, gt_labels, gt_lengths):
            every_char_hit = True
            for pos in range(length):
                topk_indices = logit[pos].topk(self.k)[1]
                target_index = label[pos].argmax(-1)
                if target_index in topk_indices:
                    self.correct_num_char += 1
                else:
                    every_char_hit = False
                self.total_num_char += 1
            # A word counts only if the predicted length matches and no character missed.
            if pt_length == length and every_char_hit:
                self.correct_num_word += 1
            self.total_num_word += 1

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append ccr and cwr, padded with three zeros, to `last_metrics`."
        char_acc = self.correct_num_char / self.total_num_char
        word_acc = self.correct_num_word / self.total_num_word
        return add_metrics(last_metrics, [char_acc, word_acc, 0., 0., 0.])
292
-
293
-
294
class DumpPrediction(LearnerCallback):
    "Writes predictions (and optionally attention visualisations) to disk per batch."

    def __init__(self, learn, dataset, charset_path, model_eval, image_only=False, debug=False):
        super().__init__(learn=learn)
        self.debug = debug
        self.model_eval = model_eval or 'alignment'
        self.image_only = image_only
        assert self.model_eval in ['vision', 'language', 'alignment']

        self.dataset = dataset
        self.root = Path(self.learn.path)/f'{dataset}-{self.model_eval}'
        self.attn_root = self.root/'attn'
        self.charset = CharsetMapper(charset_path)
        # Always start from an empty output directory.
        if self.root.exists(): shutil.rmtree(self.root)
        self.root.mkdir()
        self.attn_root.mkdir()

        self.pil = transforms.ToPILImage()
        self.tensor = transforms.ToTensor()
        size = self.learn.data.img_h, self.learn.data.img_w
        self.resize = transforms.Resize(size=size, interpolation=0)
        self.c = 0  # running sample counter, used as a file-name prefix

    def on_batch_end(self, last_input, last_output, last_target, **kwargs):
        "Append this batch's predictions to the text dumps and save one image per sample."
        if isinstance(last_output, (tuple, list)):
            for res in last_output:
                head = res['name']
                if head == self.model_eval: pt_text = res['pt_text']
                if head == 'vision': attn_scores = res['attn_scores'].detach().cpu()
                if head == self.model_eval: logits = res['logits']
        else:
            pt_text = last_output['pt_text']
            attn_scores = last_output['attn_scores'].detach().cpu()
            logits = last_output['logits']

        if isinstance(last_input, (tuple, list)): images = last_input[0]
        else: images = last_input
        images = images.detach().cpu()
        pt_text = [self.charset.trim(t) for t in pt_text]
        gt_label = last_target[0]
        if gt_label.dim() == 3: gt_label = gt_label.argmax(dim=-1)  # one-hot label
        gt_text = [self.charset.get_text(l, trim=True) for l in gt_label]

        prediction, false_prediction = [], []
        for gt, pt, image, attn, logit in zip(gt_text, pt_text, images, attn_scores, logits):
            record = f'{gt}\t{pt}\n'
            prediction.append(record)
            if gt != pt:
                if self.debug:
                    scores = torch.softmax(logit, dim=-1)[:max(len(pt), len(gt)) + 1]
                    logging.info(f'{self.c} gt {gt}, pt {pt}, logit {logit.shape}, scores {scores.topk(5, dim=-1)}')
                false_prediction.append(record)

            image = self.learn.data.denorm(image)
            out_path = self.attn_root/f'{self.c}_{gt}_{pt}.jpg'
            if self.image_only:
                self.pil(image).save(out_path)
            else:
                image_np = np.array(self.pil(image))
                attn_pil = [self.pil(a) for a in attn[:, None, :, :]]
                attn = [self.tensor(self.resize(a)).repeat(3, 1, 1) for a in attn_pil]
                attn_sum = np.array([np.array(a) for a in attn_pil[:len(pt)]]).sum(axis=0)
                blended_sum = self.tensor(blend_mask(image_np, attn_sum))
                blended = [self.tensor(blend_mask(image_np, np.array(a))) for a in attn_pil]
                save_image = torch.stack([image] + attn + [blended_sum] + blended)
                # Split the stack into two halves and interleave them before saving with nrow=2.
                save_image = save_image.view(2, -1, *save_image.shape[1:])
                save_image = save_image.permute(1, 0, 2, 3, 4).flatten(0, 1)
                vutils.save_image(save_image, out_path,
                                  nrow=2, normalize=True, scale_each=True)
            self.c += 1

        with open(self.root/f'{self.model_eval}.txt', 'a') as f: f.writelines(prediction)
        with open(self.root/f'{self.model_eval}-false.txt', 'a') as f: f.writelines(false_prediction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset.py DELETED
@@ -1,278 +0,0 @@
1
- import logging
2
- import re
3
-
4
- import cv2
5
- import lmdb
6
- import six
7
- from fastai.vision import *
8
- from torchvision import transforms
9
-
10
- from transforms import CVColorJitter, CVDeterioration, CVGeometry
11
- from utils import CharsetMapper, onehot
12
-
13
-
14
class ImageDataset(Dataset):
    "`ImageDataset` read data from LMDB database."

    def __init__(self,
                 path:PathOrStr,
                 is_training:bool=True,
                 img_h:int=32,
                 img_w:int=100,
                 max_length:int=25,
                 check_length:bool=True,
                 case_sensitive:bool=False,
                 charset_path:str='data/charset_36.txt',
                 convert_mode:str='RGB',
                 data_aug:bool=True,
                 deteriorate_ratio:float=0.,
                 multiscales:bool=True,
                 one_hot_y:bool=True,
                 return_idx:bool=False,
                 return_raw:bool=False,
                 **kwargs):
        # NOTE: `deteriorate_ratio` and `**kwargs` are accepted but not used in this class.
        self.path, self.name = Path(path), Path(path).name
        assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
        self.convert_mode, self.check_length = convert_mode, check_length
        self.img_h, self.img_w = img_h, img_w
        self.max_length, self.one_hot_y = max_length, one_hot_y
        self.return_idx, self.return_raw = return_idx, return_raw
        self.case_sensitive, self.is_training = case_sensitive, is_training
        self.data_aug, self.multiscales = data_aug, multiscales
        # One extra position in the charset mapper (labels get `len(text) + 1` as length below).
        self.charset = CharsetMapper(charset_path, max_length=max_length+1)
        self.c = self.charset.num_classes

        self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
        assert self.env, f'Cannot open LMDB dataset from {path}.'
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('num-samples'.encode()))

        # Augmentation transforms are only built when they will be used.
        if self.is_training and self.data_aug:
            self.augment_tfs = transforms.Compose([
                CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
                CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
                CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
            ])
        self.totensor = transforms.ToTensor()

    def __len__(self): return self.length

    def _next_image(self, index):
        "Return a randomly chosen sample; used as a replacement when sample `index` is rejected."
        next_index = random.randint(0, len(self) - 1)
        return self.get(next_index)

    def _check_image(self, x, pixels=6):
        "Return False when either side of image `x` is `pixels` or smaller."
        if x.size[0] <= pixels or x.size[1] <= pixels: return False
        else: return True

    def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT):
        "Resize `img` to a target ratio and pad it to `(img_h, img_w)`."
        def _resize_ratio(img, ratio, fix_h=True):
            if ratio * self.img_w < self.img_h:
                if fix_h: trg_h = self.img_h
                else: trg_h = int(ratio * self.img_w)
                trg_w = self.img_w
            else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
            img = cv2.resize(img, (trg_w, trg_h))
            # Split the padding between both sides; the odd pixel goes to top/left.
            pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
            top, bottom = math.ceil(pad_h), math.floor(pad_h)
            left, right = math.ceil(pad_w), math.floor(pad_w)
            img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
            return img

        if self.is_training:
            if random.random() < 0.5:
                # Half of the training samples get a random aspect ratio.
                base, maxh, maxw = self.img_h, self.img_h, self.img_w
                h, w = random.randint(base, maxh), random.randint(base, maxw)
                return _resize_ratio(img, h/w)
            else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio
        else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio

    def resize(self, img):
        if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
        else: return cv2.resize(img, (self.img_w, self.img_h))

    def get(self, idx):
        """Read sample `idx` and return `(image, label, idx)`.

        Samples that are rejected (label length, tiny image) or fail to load are
        replaced by a random sample, so the returned index may differ from `idx`.
        """
        with self.env.begin(write=False) as txn:
            image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
            # Pre-bind so the error handler can always report it, even when
            # reading the label itself is what failed.
            label = None
            try:
                label = str(txn.get(label_key.encode()), 'utf-8')  # label
                label = re.sub('[^0-9a-zA-Z]+', '', label)
                if self.check_length and self.max_length > 0:
                    if len(label) > self.max_length or len(label) <= 0:
                        #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
                        return self._next_image(idx)
                label = label[:self.max_length]

                imgbuf = txn.get(image_key.encode())  # image
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning)  # EXIF warning from TiffPlugin
                    image = PIL.Image.open(buf).convert(self.convert_mode)
                if self.is_training and not self._check_image(image):
                    #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
                    return self._next_image(idx)
            except Exception:
                # Best effort: report the broken sample and substitute another one.
                import traceback
                traceback.print_exc()
                label_len = len(label) if label is not None else 0
                logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {label_len}')
                return self._next_image(idx)
            return image, label, idx

    def _process_training(self, image):
        if self.data_aug: image = self.augment_tfs(image)
        image = self.resize(np.array(image))
        return image

    def _process_test(self, image):
        return self.resize(np.array(image))  # TODO:move is_training to here

    def __getitem__(self, idx):
        "Return `(image, [label, length])`, plus the sample index when `return_idx` is set."
        image, text, idx_new = self.get(idx)
        if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'

        if self.is_training: image = self._process_training(image)
        else: image = self._process_test(image)
        if self.return_raw: return image, text
        image = self.totensor(image)

        length = tensor(len(text) + 1).to(dtype=torch.long)  # one for end token
        label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
        label = tensor(label).to(dtype=torch.long)
        if self.one_hot_y: label = onehot(label, self.charset.num_classes)

        if self.return_idx: y = [label, length, idx_new]
        else: y = [label, length]
        return image, y
148
-
149
-
150
class TextDataset(Dataset):
    "Paired (input text, ground-truth text) dataset read from a delimited file."
    def __init__(self,
                 path:PathOrStr,
                 delimiter:str='\t',
                 max_length:int=25,
                 charset_path:str='data/charset_36.txt',
                 case_sensitive=False,
                 one_hot_x=True,
                 one_hot_y=True,
                 is_training=True,
                 smooth_label=False,
                 smooth_factor=0.2,
                 use_sm=False,
                 **kwargs):
        self.path = Path(path)
        self.case_sensitive = case_sensitive
        self.use_sm = use_sm
        self.smooth_factor = smooth_factor
        self.smooth_label = smooth_label
        self.charset = CharsetMapper(charset_path, max_length=max_length+1)
        self.one_hot_x = one_hot_x
        self.one_hot_y = one_hot_y
        self.is_training = is_training
        # Spelling mutation is only set up when it will actually be applied.
        if self.is_training and self.use_sm:
            self.sm = SpellingMutation(charset=self.charset)

        dtype = {'inp': str, 'gt': str}
        self.df = pd.read_csv(self.path, dtype=dtype, delimiter=delimiter, na_filter=False)
        # Column positions: first column is the input text, second the ground truth.
        self.inp_col, self.gt_col = 0, 1

    def __len__(self): return len(self.df)

    def _clean(self, text):
        "Drop non-alphanumeric characters and lowercase unless case sensitive."
        text = re.sub('[^0-9a-zA-Z]+', '', text)
        return text if self.case_sensitive else text.lower()

    def __getitem__(self, idx):
        "Return `([label_x, length_x], [label_y, length_y])` for row `idx`."
        # Input side: optionally mutated and label-smoothed during training.
        text_x = self._clean(self.df.iloc[idx, self.inp_col])
        if self.is_training and self.use_sm: text_x = self.sm(text_x)

        length_x = tensor(len(text_x) + 1).to(dtype=torch.long)  # one for end token
        label_x = tensor(self.charset.get_labels(text_x, case_sensitive=self.case_sensitive))
        if self.one_hot_x:
            label_x = onehot(label_x, self.charset.num_classes)
            if self.is_training and self.smooth_label:
                label_x = torch.stack([self.prob_smooth_label(row) for row in label_x])

        # Target side: never mutated or smoothed.
        text_y = self._clean(self.df.iloc[idx, self.gt_col])
        length_y = tensor(len(text_y) + 1).to(dtype=torch.long)  # one for end token
        label_y = tensor(self.charset.get_labels(text_y, case_sensitive=self.case_sensitive))
        if self.one_hot_y: label_y = onehot(label_y, self.charset.num_classes)

        return [label_x, length_x], [label_y, length_y]

    def prob_smooth_label(self, one_hot):
        "Mix `one_hot` with random noise whose total mass is a random fraction below `smooth_factor`."
        one_hot = one_hot.float()
        delta = torch.rand([]) * self.smooth_factor
        noise = torch.rand(len(one_hot))
        noise = noise / noise.sum() * delta
        return one_hot * (1 - delta) + noise
211
-
212
-
213
class SpellingMutation(object):
    "Randomly replaces, inserts or deletes a few characters of a text."
    def __init__(self, pn0=0.7, pn1=0.85, pn2=0.95, pt0=0.7, pt1=0.85, charset=None):
        """
        Args:
            pn0: the prob of not modifying characters is (pn0)
            pn1: the prob of modifying one characters is (pn1 - pn0)
            pn2: the prob of modifying two characters is (pn2 - pn1),
                 and three (1 - pn2)
            pt0: the prob of replacing operation is pt0.
            pt1: the prob of inserting operation is (pt1 - pt0),
                 and deleting operation is (1 - pt1)
        """
        super().__init__()
        self.pn0 = pn0
        self.pn1 = pn1
        self.pn2 = pn2
        self.pt0 = pt0
        self.pt1 = pt1
        self.charset = charset
        logging.info(f'the probs: pn0={self.pn0}, pn1={self.pn1} ' +
                     f'pn2={self.pn2}, pt0={self.pt0}, pt1={self.pt1}')

    def is_digit(self, text, ratio=0.5):
        "True when at least `ratio` of the characters of `text` are digits of the charset."
        denominator = max(len(text), 1)
        digit_count = sum([ch in self.charset.digits for ch in text])
        return not (digit_count / denominator < ratio)

    def is_unk_char(self, char):
        "True when `char` is neither a digit nor an alphabet character of the charset."
        # return char == self.charset.unk_char
        known = (char in self.charset.digits) or (char in self.charset.alphabets)
        return not known

    def get_num_to_modify(self, length):
        "Draw how many characters to modify, capped according to the text `length`."
        roll = random.random()
        if roll < self.pn0:
            wanted = 0
        elif roll < self.pn1:
            wanted = 1
        elif roll < self.pn2:
            wanted = 2
        else:
            wanted = 3

        if length <= 1:
            return 0
        if length >= 2 and length <= 4:
            return min(wanted, 1)
        return min(wanted, length // 2)  # smaller than length // 2

    def __call__(self, text, debug=False):
        "Return a mutated copy of `text`, or `text` itself when nothing is (or can be) changed."
        if self.is_digit(text):
            return text
        length = len(text)
        num_to_modify = self.get_num_to_modify(length)
        if num_to_modify <= 0:
            return text

        # Pick the positions to mutate: shuffle all positions, keep the first few.
        index = np.arange(0, length)
        random.shuffle(index)
        index = index[: num_to_modify]
        if debug: self.index = index

        chars = []
        for pos, ch in enumerate(text):
            if pos not in index or self.is_unk_char(ch):
                chars.append(ch)
                continue
            op = random.random()
            if op < self.pt0:  # replace
                chars.append(random.choice(self.charset.alphabets))
            elif op < self.pt1:  # insert
                chars.append(random.choice(self.charset.alphabets))
                chars.append(ch)
            # otherwise: delete, i.e. append nothing
        new_text = ''.join(chars[: self.charset.max_length-1])
        if len(new_text) >= 1:
            return new_text
        return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
losses.py DELETED
@@ -1,72 +0,0 @@
1
- from fastai.vision import *
2
-
3
- from modules.model import Model
4
-
5
-
6
class MultiLosses(nn.Module):
    "Weighted sum of per-output cross-entropy losses, keeping the last sub-losses by name."
    def __init__(self, one_hot=True):
        super().__init__()
        self.ce = SoftCrossEntropyLoss() if one_hot else torch.nn.CrossEntropyLoss()
        self.bce = torch.nn.BCELoss()
        # Initialised here so `last_losses` is readable before the first `forward`.
        self.losses = {}

    @property
    def last_losses(self):
        "Dict of named sub-losses recorded by the most recent `forward` call."
        return self.losses

    def _flatten(self, sources, lengths):
        "Concatenate the first `l` entries of every item of `sources`."
        return torch.cat([t[:l] for t, l in zip(sources, lengths)])

    def _merge_list(self, all_res):
        "Merge a list of result dicts into one: tensors are concatenated on dim 0, other values taken from the first."
        if not isinstance(all_res, (list, tuple)):
            return all_res
        def merge(items):
            if isinstance(items[0], torch.Tensor): return torch.cat(items, dim=0)
            else: return items[0]
        res = dict()
        for key in all_res[0].keys():
            items = [r[key] for r in all_res]
            res[key] = merge(items)
        return res

    def _ce_loss(self, output, gt_labels, gt_lengths, idx=None, record=True):
        "Weighted cross entropy of one output dict against the labels, over the valid positions only."
        loss_name = output.get('name')
        pt_logits, weight = output['logits'], output['loss_weight']

        # Logits may hold several stacked passes over the same batch.
        assert pt_logits.shape[0] % gt_labels.shape[0] == 0
        iter_size = pt_logits.shape[0] // gt_labels.shape[0]
        if iter_size > 1:
            # Repeat along the batch dimension only, by the actual number of
            # passes (was hard-coded to 3 and to 3-D labels).
            gt_labels = gt_labels.repeat(iter_size, *([1] * (gt_labels.dim() - 1)))
            gt_lengths = gt_lengths.repeat(iter_size)
        flat_gt_labels = self._flatten(gt_labels, gt_lengths)
        flat_pt_logits = self._flatten(pt_logits, gt_lengths)

        nll = output.get('nll')
        if nll is not None:
            # The output dict carries an `nll` entry: skip the softmax inside the loss.
            loss = self.ce(flat_pt_logits, flat_gt_labels, softmax=False) * weight
        else:
            loss = self.ce(flat_pt_logits, flat_gt_labels) * weight
        if record and loss_name is not None: self.losses[f'{loss_name}_loss'] = loss

        return loss

    def forward(self, outputs, *args):
        "Return the summed loss over all outputs with a positive `loss_weight`."
        self.losses = {}
        if isinstance(outputs, (tuple, list)):
            outputs = [self._merge_list(o) for o in outputs]
            return sum([self._ce_loss(o, *args) for o in outputs if o['loss_weight'] > 0.])
        else:
            return self._ce_loss(outputs, *args, record=False)
59
-
60
-
61
class SoftCrossEntropyLoss(nn.Module):
    "Cross entropy between predictions and a (possibly soft) target distribution."
    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, input, target, softmax=True):
        "With `softmax=False`, `input` is passed straight to `torch.log` instead of `log_softmax`."
        log_prob = F.log_softmax(input, dim=-1) if softmax else torch.log(input)
        per_item = -(target * log_prob).sum(dim=-1)
        if self.reduction == "mean":
            return per_item.mean()
        if self.reduction == "sum":
            return per_item.sum()
        # Any other reduction value: return the unreduced losses.
        return per_item
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,246 +0,0 @@
1
- import argparse
2
- import logging
3
- import os
4
- import random
5
-
6
- import torch
7
- from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
8
- from fastai.distributed import *
9
- from fastai.vision import *
10
- from torch.backends import cudnn
11
-
12
- from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
13
- from dataset import ImageDataset, TextDataset
14
- from losses import MultiLosses
15
- from utils import Config, Logger, MyDataParallel, MyConcatDataset
16
-
17
-
18
- def _set_random_seed(seed):
19
- if seed is not None:
20
- random.seed(seed)
21
- torch.manual_seed(seed)
22
- cudnn.deterministic = True
23
- logging.warning('You have chosen to seed training. '
24
- 'This will slow down your training!')
25
-
26
- def _get_training_phases(config, n):
27
- lr = np.array(config.optimizer_lr)
28
- periods = config.optimizer_scheduler_periods
29
- sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
30
- phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
31
- for i in range(len(periods))]
32
- return phases
33
-
34
- def _get_dataset(ds_type, paths, is_training, config, **kwargs):
35
- kwargs.update({
36
- 'img_h': config.dataset_image_height,
37
- 'img_w': config.dataset_image_width,
38
- 'max_length': config.dataset_max_length,
39
- 'case_sensitive': config.dataset_case_sensitive,
40
- 'charset_path': config.dataset_charset_path,
41
- 'data_aug': config.dataset_data_aug,
42
- 'deteriorate_ratio': config.dataset_deteriorate_ratio,
43
- 'is_training': is_training,
44
- 'multiscales': config.dataset_multiscales,
45
- 'one_hot_y': config.dataset_one_hot_y,
46
- })
47
- datasets = [ds_type(p, **kwargs) for p in paths]
48
- if len(datasets) > 1: return MyConcatDataset(datasets)
49
- else: return datasets[0]
50
-
51
-
52
- def _get_language_databaunch(config):
53
- kwargs = {
54
- 'max_length': config.dataset_max_length,
55
- 'case_sensitive': config.dataset_case_sensitive,
56
- 'charset_path': config.dataset_charset_path,
57
- 'smooth_label': config.dataset_smooth_label,
58
- 'smooth_factor': config.dataset_smooth_factor,
59
- 'one_hot_y': config.dataset_one_hot_y,
60
- 'use_sm': config.dataset_use_sm,
61
- }
62
- train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
63
- valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
64
- data = DataBunch.create(
65
- path=train_ds.path,
66
- train_ds=train_ds,
67
- valid_ds=valid_ds,
68
- bs=config.dataset_train_batch_size,
69
- val_bs=config.dataset_test_batch_size,
70
- num_workers=config.dataset_num_workers,
71
- pin_memory=config.dataset_pin_memory)
72
- logging.info(f'{len(data.train_ds)} training items found.')
73
- if not data.empty_val:
74
- logging.info(f'{len(data.valid_ds)} valid items found.')
75
- return data
76
-
77
- def _get_databaunch(config):
78
- # An awkward way to reduce loadding data time during test
79
- if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
80
- train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
81
- valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
82
- data = ImageDataBunch.create(
83
- train_ds=train_ds,
84
- valid_ds=valid_ds,
85
- bs=config.dataset_train_batch_size,
86
- val_bs=config.dataset_test_batch_size,
87
- num_workers=config.dataset_num_workers,
88
- pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
89
- ar_tfm = lambda x: ((x[0], x[1]), x[1]) # auto-regression only for dtd
90
- data.add_tfm(ar_tfm)
91
-
92
- logging.info(f'{len(data.train_ds)} training items found.')
93
- if not data.empty_val:
94
- logging.info(f'{len(data.valid_ds)} valid items found.')
95
-
96
- return data
97
-
98
- def _get_model(config):
99
- import importlib
100
- names = config.model_name.split('.')
101
- module_name, class_name = '.'.join(names[:-1]), names[-1]
102
- cls = getattr(importlib.import_module(module_name), class_name)
103
- model = cls(config)
104
- logging.info(model)
105
- return model
106
-
107
-
108
- def _get_learner(config, data, model, local_rank=None):
109
- strict = ifnone(config.model_strict, True)
110
- if config.global_stage == 'pretrain-language':
111
- metrics = [TopKTextAccuracy(
112
- k=ifnone(config.model_k, 5),
113
- charset_path=config.dataset_charset_path,
114
- max_length=config.dataset_max_length + 1,
115
- case_sensitive=config.dataset_eval_case_sensisitves,
116
- model_eval=config.model_eval)]
117
- else:
118
- metrics = [TextAccuracy(
119
- charset_path=config.dataset_charset_path,
120
- max_length=config.dataset_max_length + 1,
121
- case_sensitive=config.dataset_eval_case_sensisitves,
122
- model_eval=config.model_eval)]
123
- opt_type = getattr(torch.optim, config.optimizer_type)
124
- learner = Learner(data, model, silent=True, model_dir='.',
125
- true_wd=config.optimizer_true_wd,
126
- wd=config.optimizer_wd,
127
- bn_wd=config.optimizer_bn_wd,
128
- path=config.global_workdir,
129
- metrics=metrics,
130
- opt_func=partial(opt_type, **config.optimizer_args or dict()),
131
- loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
132
- learner.split(lambda m: children(m))
133
-
134
- if config.global_phase == 'train':
135
- num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
136
- phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
137
- learner.callback_fns += [
138
- partial(GeneralScheduler, phases=phases),
139
- partial(GradientClipping, clip=config.optimizer_clip_grad),
140
- partial(IterationCallback, name=config.global_name,
141
- show_iters=config.training_show_iters,
142
- eval_iters=config.training_eval_iters,
143
- save_iters=config.training_save_iters,
144
- start_iters=config.training_start_iters,
145
- stats_iters=config.training_stats_iters)]
146
- else:
147
- learner.callbacks += [
148
- DumpPrediction(learn=learner,
149
- dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),charset_path=config.dataset_charset_path,
150
- model_eval=config.model_eval,
151
- debug=config.global_debug,
152
- image_only=config.global_image_only)]
153
-
154
- learner.rank = local_rank
155
- if local_rank is not None:
156
- logging.info(f'Set model to distributed with rank {local_rank}.')
157
- learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
158
- learner.model.to(local_rank)
159
- learner = learner.to_distributed(local_rank)
160
-
161
- if torch.cuda.device_count() > 1 and local_rank is None:
162
- logging.info(f'Use {torch.cuda.device_count()} GPUs.')
163
- learner.model = MyDataParallel(learner.model)
164
-
165
- if config.model_checkpoint:
166
- if Path(config.model_checkpoint).exists():
167
- with open(config.model_checkpoint, 'rb') as f:
168
- buffer = io.BytesIO(f.read())
169
- learner.load(buffer, strict=strict)
170
- else:
171
- from distutils.dir_util import copy_tree
172
- src = Path('/data/fangsc/model')/config.global_name
173
- trg = Path('/output')/config.global_name
174
- if src.exists(): copy_tree(str(src), str(trg))
175
- learner.load(config.model_checkpoint, strict=strict)
176
- logging.info(f'Read model from {config.model_checkpoint}')
177
- elif config.global_phase == 'test':
178
- learner.load(f'best-{config.global_name}', strict=strict)
179
- logging.info(f'Read model from best-{config.global_name}')
180
-
181
- if learner.opt_func.func.__name__ == 'Adadelta': # fastai bug, fix after 1.0.60
182
- learner.fit(epochs=0, lr=config.optimizer_lr)
183
- learner.opt.mom = 0.
184
-
185
- return learner
186
-
187
- def main():
188
- parser = argparse.ArgumentParser()
189
- parser.add_argument('--config', type=str, required=True,
190
- help='path to config file')
191
- parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
192
- parser.add_argument('--name', type=str, default=None)
193
- parser.add_argument('--checkpoint', type=str, default=None)
194
- parser.add_argument('--test_root', type=str, default=None)
195
- parser.add_argument("--local_rank", type=int, default=None)
196
- parser.add_argument('--debug', action='store_true', default=None)
197
- parser.add_argument('--image_only', action='store_true', default=None)
198
- parser.add_argument('--model_strict', action='store_false', default=None)
199
- parser.add_argument('--model_eval', type=str, default=None,
200
- choices=['alignment', 'vision', 'language'])
201
- args = parser.parse_args()
202
- config = Config(args.config)
203
- if args.name is not None: config.global_name = args.name
204
- if args.phase is not None: config.global_phase = args.phase
205
- if args.test_root is not None: config.dataset_test_roots = [args.test_root]
206
- if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
207
- if args.debug is not None: config.global_debug = args.debug
208
- if args.image_only is not None: config.global_image_only = args.image_only
209
- if args.model_eval is not None: config.model_eval = args.model_eval
210
- if args.model_strict is not None: config.model_strict = args.model_strict
211
-
212
- Logger.init(config.global_workdir, config.global_name, config.global_phase)
213
- Logger.enable_file()
214
- _set_random_seed(config.global_seed)
215
- logging.info(config)
216
-
217
- if args.local_rank is not None:
218
- logging.info(f'Init distribution training at device {args.local_rank}.')
219
- torch.cuda.set_device(args.local_rank)
220
- torch.distributed.init_process_group(backend='nccl', init_method='env://')
221
-
222
- logging.info('Construct dataset.')
223
- if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
224
- else: data = _get_databaunch(config)
225
-
226
- logging.info('Construct model.')
227
- model = _get_model(config)
228
-
229
- logging.info('Construct learner.')
230
- learner = _get_learner(config, data, model, args.local_rank)
231
-
232
- if config.global_phase == 'train':
233
- logging.info('Start training.')
234
- learner.fit(epochs=config.training_epochs,
235
- lr=config.optimizer_lr)
236
- else:
237
- logging.info('Start validate')
238
- last_metrics = learner.validate()
239
- log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
240
- f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
241
- f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
242
- f'ted/w = {last_metrics[5]:6.3f}, '
243
- logging.info(log_str)
244
-
245
- if __name__ == '__main__':
246
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transforms.py DELETED
@@ -1,329 +0,0 @@
1
- import math
2
- import numbers
3
- import random
4
-
5
- import cv2
6
- import numpy as np
7
- from PIL import Image
8
- from torchvision import transforms
9
- from torchvision.transforms import Compose
10
-
11
-
12
def sample_asym(magnitude, size=None):
    """Draw Beta(1, 4) samples scaled by ``magnitude`` (values in [0, magnitude])."""
    draws = np.random.beta(1, 4, size=size)
    return draws * magnitude
14
-
15
def sample_sym(magnitude, size=None):
    """Draw Beta(4, 4) samples recentred to lie in [-magnitude, magnitude]."""
    centred = np.random.beta(4, 4, size=size) - 0.5
    return centred * 2 * magnitude
17
-
18
def sample_uniform(low, high, size=None):
    """Draw uniformly distributed samples from the interval [low, high)."""
    return np.random.uniform(low=low, high=high, size=size)
20
-
21
def get_interpolation(type='random'):
    """Return the OpenCV interpolation flag for ``type``.

    ``type`` is one of 'nearest', 'linear', 'cubic', 'area', or 'random' to
    pick one of those four at random. Any other value raises ``TypeError``.
    """
    flag_names = (
        ('nearest', 'INTER_NEAREST'),
        ('linear', 'INTER_LINEAR'),
        ('cubic', 'INTER_CUBIC'),
        ('area', 'INTER_AREA'),
    )
    if type == 'random':
        candidates = [getattr(cv2, attr) for _, attr in flag_names]
        return candidates[random.randint(0, len(candidates) - 1)]
    for key, attr in flag_names:
        if type == key:
            return getattr(cv2, attr)
    raise TypeError('Interpolation types only nearest, linear, cubic, area are supported!')
31
-
32
class CVRandomRotation(object):
    """Rotate an image array by a random angle, enlarging the canvas to fit."""

    def __init__(self, degrees=15):
        """``degrees`` is the non-negative magnitude bounding the sampled angle."""
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

    @staticmethod
    def get_params(degrees):
        """Sample a rotation angle symmetric around zero."""
        return sample_sym(degrees)

    def __call__(self, img):
        """Return ``img`` rotated about its centre with replicated borders."""
        angle = self.get_params(self.degrees)
        height, width = img.shape[:2]
        centre = (width / 2, height / 2)
        M = cv2.getRotationMatrix2D(center=centre, angle=angle, scale=1.0)

        # Size of the axis-aligned box that contains the rotated image.
        cos_a = abs(M[0, 0])
        sin_a = abs(M[0, 1])
        out_w = int(height * sin_a + width * cos_a)
        out_h = int(height * cos_a + width * sin_a)

        # Shift so the rotated content stays centred in the larger canvas.
        M[0, 2] += (out_w - width) / 2
        M[1, 2] += (out_h - height) / 2

        return cv2.warpAffine(img, M, (out_w, out_h), flags=get_interpolation(),
                              borderMode=cv2.BORDER_REPLICATE)
54
-
55
class CVRandomAffine(object):
    """Random affine warp (rotation, translation, scale, shear) via OpenCV.

    The output canvas is grown to hold the transformed image; borders are
    filled by edge replication.
    """

    def __init__(self, degrees, translate=None, scale=None, shear=None):
        """Validate and store the sampling ranges.

        Args:
            degrees: non-negative magnitude bounding the rotation angle.
            translate: optional pair of fractions in [0, 1].
            scale: optional (low, high) pair of positive scale factors.
            shear: optional non-negative number (x-shear only) or pair of
                magnitudes for x- and y-shear.

        Raises:
            AssertionError: on wrongly typed or shaped arguments.
            ValueError: on out-of-range translate, scale or shear values.
        """
        assert isinstance(degrees, numbers.Number), "degree should be a single number."
        assert degrees >= 0, "degree must be positive."
        self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = [shear]
            else:
                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
                    "shear should be a list or tuple and it must be of length 2."
                self.shear = shear
        else:
            self.shear = shear

    def _get_inverse_affine_matrix(self, center, angle, translate, scale, shear):
        """Return the inverse affine matrix as a flat list of six numbers.

        Angles (``angle`` and ``shear``) are in degrees.

        Raises:
            ValueError: if ``shear`` is neither a number nor a 2-element
                tuple/list.
        """
        # https://github.com/pytorch/vision/blob/v0.4.0/torchvision/transforms/functional.py#L717
        from numpy import sin, cos, tan

        if isinstance(shear, numbers.Number):
            shear = [shear, 0]

        # The original condition was `not isinstance(...) and len(...) == 2`,
        # which let wrongly sized sequences through; negate the whole test.
        if not (isinstance(shear, (tuple, list)) and len(shear) == 2):
            raise ValueError(
                "Shear should be a single value or a tuple/list containing " +
                "two values. Got {}".format(shear))

        rot = math.radians(angle)
        sx, sy = [math.radians(s) for s in shear]

        cx, cy = center
        tx, ty = translate

        # RSS without scaling
        a = cos(rot - sy) / cos(sy)
        b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
        c = sin(rot - sy) / cos(sy)
        d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)

        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        M = [d, -b, 0,
             -c, a, 0]
        M = [x / scale for x in M]

        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
        M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)

        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
        M[2] += cx
        M[5] += cy
        return M

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, height):
        """Sample (angle, translations, scale, shear) for one call.

        Both translation components are bounded by a fraction of ``height``.
        """
        angle = sample_sym(degrees)
        if translate is not None:
            max_dx = translate[0] * height
            max_dy = translate[1] * height
            translations = (np.round(sample_sym(max_dx)), np.round(sample_sym(max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = sample_uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 1:
                shear = [sample_sym(shears[0]), 0.]
            elif len(shears) == 2:
                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img):
        """Apply a freshly sampled affine warp to the image array ``img``."""
        src_h, src_w = img.shape[:2]
        angle, translate, scale, shear = self.get_params(
            self.degrees, self.translate, self.scale, self.shear, src_h)

        # The matrix is built without translation; translation is added below.
        M = self._get_inverse_affine_matrix((src_w/2, src_h/2), angle, (0, 0), scale, shear)
        M = np.array(M).reshape(2, 3)

        # Map the four image corners to find the extent of the warped image.
        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1), (0, src_h - 1)]
        project = lambda x, y, a, b, c: int(a*x + b*y + c)
        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1])) for x, y in startpoints]

        rect = cv2.minAreaRect(np.array(endpoints))
        # `np.int` was removed in NumPy 1.24; the builtin int is equivalent.
        bbox = cv2.boxPoints(rect).astype(dtype=int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()

        dst_w = int(max_x - min_x)
        dst_h = int(max_y - min_y)
        M[0, 2] += (dst_w - src_w) / 2
        M[1, 2] += (dst_h - src_h) / 2

        # add translate
        dst_w += int(abs(translate[0]))
        dst_h += int(abs(translate[1]))
        if translate[0] < 0: M[0, 2] += abs(translate[0])
        if translate[1] < 0: M[1, 2] += abs(translate[1])

        flags = get_interpolation()
        return cv2.warpAffine(img, M, (dst_w, dst_h), flags=flags, borderMode=cv2.BORDER_REPLICATE)
184
-
185
class CVRandomPerspective(object):
    """Random perspective warp that pulls each image corner inwards."""

    def __init__(self, distortion=0.5):
        """``distortion`` scales the maximum corner offset (see get_params)."""
        self.distortion = distortion

    def get_params(self, width, height, distortion):
        """Sample the source and destination corners of the warp.

        Each corner is moved inwards by up to ``distortion * size / 2``.

        Returns:
            (startpoints, endpoints) as float32 arrays of four (x, y) points
            ordered top-left, top-right, bottom-right, bottom-left.
        """
        # `np.int` was removed in NumPy 1.24; the builtin int is equivalent.
        offset_h = sample_asym(distortion * height / 2, size=4).astype(dtype=int)
        offset_w = sample_asym(distortion * width / 2, size=4).astype(dtype=int)
        topleft = (offset_w[0], offset_h[0])
        topright = (width - 1 - offset_w[1], offset_h[1])
        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
        botleft = (offset_w[3], height - 1 - offset_h[3])

        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
        endpoints = [topleft, topright, botright, botleft]
        return np.array(startpoints, dtype=np.float32), np.array(endpoints, dtype=np.float32)

    def __call__(self, img):
        """Warp the image array ``img`` and crop it to the warped region."""
        height, width = img.shape[:2]
        startpoints, endpoints = self.get_params(width, height, self.distortion)
        M = cv2.getPerspectiveTransform(startpoints, endpoints)

        # TODO: more robust way to crop image
        rect = cv2.minAreaRect(endpoints)
        bbox = cv2.boxPoints(rect).astype(dtype=int)
        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
        # Clamp the crop origin so the slicing below never uses negative indices.
        min_x, min_y = max(min_x, 0), max(min_y, 0)

        flags = get_interpolation()
        img = cv2.warpPerspective(img, M, (max_x, max_y), flags=flags, borderMode=cv2.BORDER_REPLICATE)
        img = img[min_y:, min_x:]
        return img
217
-
218
class CVRescale(object):
    """Degrade resolution by down-sampling through a Gaussian pyramid."""

    def __init__(self, factor=4, base_size=(128, 512)):
        """ Define image scales using gaussian pyramid and rescale image to target scale.

        Args:
            factor: the decayed factor from base size, factor=4 keeps target scale by default.
                A number samples the level from [0, factor]; a 2-element
                tuple/list samples it from [factor[0], factor[1]].
            base_size: base size (height, width) used to build the bottom layer of the pyramid
        """
        if isinstance(factor, numbers.Number):
            low, high = 0, factor
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            low, high = factor[0], factor[1]
        else:
            raise Exception('factor must be number or list with length 2')
        # The pyramid level is drawn once, at construction time.
        self.factor = round(sample_uniform(low, high))
        self.base_h, self.base_w = base_size[:2]

    def __call__(self, img):
        """Return ``img`` at its original size after ``self.factor`` pyrDown steps."""
        if self.factor == 0:
            return img
        orig_h, orig_w = img.shape[:2]
        shrunk = cv2.resize(img, (self.base_w, self.base_h), interpolation=get_interpolation())
        for _ in range(self.factor):
            shrunk = cv2.pyrDown(shrunk)
        return cv2.resize(shrunk, (orig_w, orig_h), interpolation=get_interpolation())
245
-
246
class CVGaussianNoise(object):
    """Add zero-mean (by default) Gaussian noise to an image array."""

    def __init__(self, mean=0, var=20):
        """Sample the noise variance once, at construction time.

        Args:
            mean: mean of the noise.
            var: a number (variance sampled asymmetrically up to ``var``,
                at least 1) or a 2-element tuple/list (sampled uniformly
                between the two values).

        Raises:
            Exception: if ``var`` is neither of the accepted forms.
        """
        self.mean = mean
        if isinstance(var, numbers.Number):
            self.var = max(int(sample_asym(var)), 1)
        elif isinstance(var, (tuple, list)) and len(var) == 2:
            self.var = int(sample_uniform(var[0], var[1]))
        else:
            # The message used to say 'degree', copied from CVMotionBlur.
            raise Exception('var must be number or list with length 2')

    def __call__(self, img):
        """Return ``img`` plus noise, clipped to [0, 255] and cast to uint8."""
        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
        img = np.clip(img + noise, 0, 255).astype(np.uint8)
        return img
260
-
261
class CVMotionBlur(object):
    """Blur an image with a rotated line kernel to imitate motion blur."""

    def __init__(self, degrees=12, angle=90):
        """Sample the kernel size and direction once, at construction time.

        Args:
            degrees: a number (kernel size sampled asymmetrically up to it,
                at least 1) or a 2-element tuple/list (sampled uniformly).
            angle: the blur direction is sampled uniformly from [-angle, angle].
        """
        if isinstance(degrees, numbers.Number):
            size = max(int(sample_asym(degrees)), 1)
        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
            size = int(sample_uniform(degrees[0], degrees[1]))
        else:
            raise Exception('degree must be number or list with length 2')
        self.degree = size
        self.angle = sample_uniform(-angle, angle)

    def __call__(self, img):
        """Return ``img`` filtered with the motion kernel, as uint8."""
        size = self.degree
        half = size // 2
        # Start from a horizontal line through the kernel centre, then rotate it.
        kernel = np.zeros((size, size))
        kernel[half, :] = 1
        M = cv2.getRotationMatrix2D((half, half), self.angle, 1)
        kernel = cv2.warpAffine(kernel, M, (size, size))
        kernel = kernel / size
        blurred = cv2.filter2D(img, -1, kernel)
        return np.clip(blurred, 0, 255).astype(np.uint8)
280
-
281
class CVGeometry(object):
    """Apply one randomly chosen geometric transform with probability ``p``.

    The kind of transform (rotation, affine or perspective) is chosen once,
    when the object is constructed, not on every call.
    """

    def __init__(self, degrees=15, translate=(0.3, 0.3), scale=(0.5, 2.),
                 shear=(45, 15), distortion=0.5, p=0.5):
        self.p = p
        pick = random.random()
        if pick < 0.33:
            chosen = CVRandomRotation(degrees=degrees)
        elif pick < 0.66:
            chosen = CVRandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)
        else:
            chosen = CVRandomPerspective(distortion=distortion)
        self.transforms = chosen

    def __call__(self, img):
        """Return a transformed PIL image, or ``img`` unchanged when skipped."""
        if not random.random() < self.p:
            return img
        pixels = np.array(img)
        return Image.fromarray(self.transforms(pixels))
298
-
299
class CVDeterioration(object):
    """Apply noise, motion blur and rescaling in random order with probability ``p``.

    Each of ``var``, ``degrees`` and ``factor`` enables its transform unless
    it is None. The order is shuffled once, at construction time.
    """

    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        steps = []
        if var is not None:
            steps.append(CVGaussianNoise(var=var))
        if degrees is not None:
            steps.append(CVMotionBlur(degrees=degrees))
        if factor is not None:
            steps.append(CVRescale(factor=factor))

        random.shuffle(steps)
        self.transforms = Compose(steps)

    def __call__(self, img):
        """Return a deteriorated PIL image, or ``img`` unchanged when skipped."""
        if not random.random() < self.p:
            return img
        pixels = np.array(img)
        return Image.fromarray(self.transforms(pixels))
319
-
320
-
321
class CVColorJitter(object):
    """Apply torchvision's ``ColorJitter`` with probability ``p``."""

    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5):
        self.p = p
        jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast,
                                        saturation=saturation, hue=hue)
        self.transforms = jitter

    def __call__(self, img):
        """Return the jittered image, or ``img`` unchanged when skipped."""
        if not random.random() < self.p:
            return img
        return self.transforms(img)