Niv Sardi committed
Commit c92a751
1 Parent(s): f0a5526

augment: accept command line arguments
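
A minimal usage sketch for the new interface (paths are placeholders, not
from the repo; note that --dst stores into args.dest, which process() reads):

    # shell equivalent:
    #   python python/augment.py --logos data/logos --backgrounds data/backgrounds --dst data/augmented
    # programmatic equivalent, assuming python/augment.py is importable:
    from argparse import Namespace
    from augment import process

    process(Namespace(logos='data/logos',
                      backgrounds='data/backgrounds',
                      dest='data/augmented'))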

Files changed (1):
  python/augment.py: +140 -113
python/augment.py CHANGED
@@ -12,6 +12,8 @@ import cv2
 import filetype
 from filetype.match import image_matchers
 
+from progress.bar import ChargingBar
+
 import imgaug as ia
 from imgaug import augmenters as iaa
 from imgaug.augmentables.batches import UnnormalizedBatch
@@ -23,145 +25,170 @@ import pipelines
 
 BATCH_SIZE = 16
 
-mkdir.make_dirs([defaults.AUGMENTED_IMAGES_PATH, defaults.AUGMENTED_LABELS_PATH])
-
-logo_images = []
-logo_alphas = []
-logo_labels = {}
-
-db = {}
-with open(defaults.MAIN_CSV_PATH, 'r') as f:
-    reader = csv.DictReader(f)
-    db = {e.bco: e for e in [Entity.from_dict(d) for d in reader]}
-
-background_images = [d for d in os.scandir(defaults.IMAGES_PATH)]
-
-stats = {
-    'failed': 0,
-    'ok': 0
-}
-
-for d in os.scandir(defaults.LOGOS_DATA_PATH):
-    img = None
-    if not d.is_file():
-        stats['failed'] += 1
-        continue
-
-    try:
-        if filetype.match(d.path, matchers=image_matchers):
-            img = cv2.imread(d.path, cv2.IMREAD_UNCHANGED)
-        else:
-            png = svg2png(url=d.path)
-            img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
-        label = db[d.name.split('.')[0]].id
-
-        (h, w, c) = img.shape
-        if c == 3:
-            img = imtool.add_alpha(img)
-
-        if img.ndim < 3:
-            print(f'very bad dim: {img.ndim}')
-
-        img = imtool.remove_white(img)
-        (h, w, c) = img.shape
-
-        assert(w > 10)
-        assert(h > 10)
-
-        stats['ok'] += 1
-
-        (b, g, r, _) = cv2.split(img)
-        alpha = img[:, :, 3]/255
-        d = cv2.merge([b, g, r])
-
-        logo_images.append(d)
-        # tried id(), tried __array_interface__, tried tagging; nothing works
-        logo_labels.update({d.tobytes(): label})
-
-        # XXX(xaiki): we pass alpha as a float32 heatmap,
-        # because imgaug is pretty strict about what data it will process,
-        # and we want the alpha layer to go through the same transformations as the original
-        logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
-
-    except Exception as e:
-        stats['failed'] += 1
-        print(f'error loading: {d.path}: {e}')
-
-print(stats)
-#print(len(logo_alphas), len(logo_images), len(logo_labels))
-assert(len(logo_alphas) == len(logo_images))
-
-# shuffle so that we don't get a lot of the same logos on the same page
-zipped = list(zip(logo_images, logo_alphas))
-random.shuffle(zipped)
-logo_images, logo_alphas = zip(*zipped)
-
-n = len(logo_images)
-batches = []
-for i in range(math.floor(n*2/BATCH_SIZE)):
-    s = (i*BATCH_SIZE)%n
-    e = min(s + BATCH_SIZE, n)
-    le = max(0, BATCH_SIZE - (e - s))
-
-    a = logo_images[0:le] + logo_images[s:e]
-    h = logo_alphas[0:le] + logo_alphas[s:e]
-
-    assert(len(a) == BATCH_SIZE)
-
-    batches.append(UnnormalizedBatch(images=a, heatmaps=h))
-
-# We use a single, very fast augmenter here to show that batches
-# are only loaded once there is space again in the buffer.
-pipeline = pipelines.HUGE
-
-def create_generator(lst):
-    for b in lst:
-        print("Loading next unaugmented batch...")
-        yield b
-
-batches_generator = create_generator(batches)
-
-with pipeline.pool(processes=-1, seed=1) as pool:
-    batches_aug = pool.imap_batches(batches_generator, output_buffer_size=5)
-
-    print("Requesting next augmented batch...")
-    for i, batch_aug in enumerate(batches_aug):
-        idx = list(range(len(batch_aug.images_aug)))
-        random.shuffle(idx)
-        for j, d in enumerate(background_images):
-            img = imtool.remove_white(cv2.imread(d.path))
-            basename = d.name.replace('.png', '') + f'.{i}.{j}'
-
-            anotations = []
-            for k in range(math.floor(len(batch_aug.images_aug)/3)):
-                logo_idx = (j+k*4)%len(batch_aug.images_aug)
-
-                orig = batch_aug.images_unaug[logo_idx]
-                label = logo_labels[orig.tobytes()]
-                logo = batch_aug.images_aug[logo_idx]
-
-                assert(logo.shape == orig.shape)
-
-                # XXX(xaiki): we get alpha from the heatmap, but will only use one channel;
-                # we could make mix_alpha into mix_mask and pass all 3 channels
-                alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
-
-                try:
-                    bb = imtool.mix_alpha(img, logo, alpha[0],
-                                          random.random(), random.random())
-                    c = bb.to_centroid(img.shape)
-                    anotations.append(c.to_anotation(label))
-                except AssertionError as e:
-                    print(f'couldnt process {i}, {j}: {e}')
-
-            try:
-                cv2.imwrite(f'{defaults.AUGMENTED_IMAGES_PATH}/{basename}.png', img)
-                label_path = f"{defaults.AUGMENTED_LABELS_PATH}/{basename}.txt"
-                with open(label_path, 'a') as f:
-                    f.write('\n'.join(anotations))
-            except Exception:
-                print(f'couldnt write image {basename}')
-
-        if i < len(batches)-1:
-            print("Requesting next augmented batch...")
-
+def process(args):
+
+    dest_images_path = os.path.join(args.dest, 'images')
+    dest_labels_path = os.path.join(args.dest, 'labels')
+
+    mkdir.make_dirs([dest_images_path, dest_labels_path])
+    logo_images = []
+    logo_alphas = []
+    logo_labels = {}
+
+    db = {}
+    with open(defaults.MAIN_CSV_PATH, 'r') as f:
+        reader = csv.DictReader(f)
+        db = {e.bco: e for e in [Entity.from_dict(d) for d in reader]}
+
+    background_images = [d for d in os.scandir(args.backgrounds)]
+    assert(len(background_images))
+
+    stats = {
+        'failed': 0,
+        'ok': 0
+    }
+
+    for d in os.scandir(args.logos):
+        img = None
+        if not d.is_file():
+            stats['failed'] += 1
+            continue
+
+        try:
+            if filetype.match(d.path, matchers=image_matchers):
+                img = cv2.imread(d.path, cv2.IMREAD_UNCHANGED)
+            else:
+                png = svg2png(url=d.path)
+                img = cv2.imdecode(np.asarray(bytearray(png), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+            label = db[d.name.split('.')[0]].id
+
+            (h, w, c) = img.shape
+            if c == 3:
+                img = imtool.add_alpha(img)
+
+            if img.ndim < 3:
+                print(f'very bad dim: {img.ndim}')
+
+            img = imtool.remove_white(img)
+            (h, w, c) = img.shape
+
+            assert(w > 10)
+            assert(h > 10)
+
+            stats['ok'] += 1
+
+            (b, g, r, _) = cv2.split(img)
+            alpha = img[:, :, 3]/255
+            d = cv2.merge([b, g, r])
+
+            logo_images.append(d)
+            # tried id(), tried __array_interface__, tried tagging; nothing works
+            logo_labels.update({d.tobytes(): label})
+
+            # XXX(xaiki): we pass alpha as a float32 heatmap,
+            # because imgaug is pretty strict about what data it will process,
+            # and we want the alpha layer to go through the same transformations as the original
+            logo_alphas.append(np.dstack((alpha, alpha, alpha)).astype('float32'))
+
+        except Exception as e:
+            stats['failed'] += 1
+            print(f'error loading: {d.path}: {e}')
+
+    print(stats)
+    #print(len(logo_alphas), len(logo_images), len(logo_labels))
+    assert(len(logo_alphas) == len(logo_images))
+
+    # shuffle so that we don't get a lot of the same logos on the same page
+    zipped = list(zip(logo_images, logo_alphas))
+    random.shuffle(zipped)
+    logo_images, logo_alphas = zip(*zipped)
+
+    n = len(logo_images)
+    batches = []
+    for i in range(math.floor(n*2/BATCH_SIZE)):
+        s = (i*BATCH_SIZE)%n
+        e = min(s + BATCH_SIZE, n)
+        le = max(0, BATCH_SIZE - (e - s))
+
+        a = logo_images[0:le] + logo_images[s:e]
+        h = logo_alphas[0:le] + logo_alphas[s:e]
+
+        assert(len(a) == BATCH_SIZE)
+
+        batches.append(UnnormalizedBatch(images=a, heatmaps=h))
+
+    bar = ChargingBar('Processing', max=len(batches))
+    # We use a single, very fast augmenter here to show that batches
+    # are only loaded once there is space again in the buffer.
+    pipeline = pipelines.HUGE
+
+    def create_generator(lst):
+        for b in lst:
+            print("Loading next unaugmented batch...")
+            yield b
+
+    batches_generator = create_generator(batches)
+
+    with pipeline.pool(processes=-1, seed=1) as pool:
+        batches_aug = pool.imap_batches(batches_generator, output_buffer_size=5)
+
+        print("Requesting next augmented batch...")
+        for i, batch_aug in enumerate(batches_aug):
+            idx = list(range(len(batch_aug.images_aug)))
+            random.shuffle(idx)
+            for j, d in enumerate(background_images):
+                img = imtool.remove_white(cv2.imread(d.path))
+                basename = d.name.replace('.png', '') + f'.{i}.{j}'
+
+                anotations = []
+                for k in range(math.floor(len(batch_aug.images_aug)/3)):
+                    logo_idx = (j+k*4)%len(batch_aug.images_aug)
+
+                    orig = batch_aug.images_unaug[logo_idx]
+                    label = logo_labels[orig.tobytes()]
+                    logo = batch_aug.images_aug[logo_idx]
+
+                    assert(logo.shape == orig.shape)
+
+                    # XXX(xaiki): we get alpha from the heatmap, but will only use one channel;
+                    # we could make mix_alpha into mix_mask and pass all 3 channels
+                    alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
+
+                    try:
+                        bb = imtool.mix_alpha(img, logo, alpha[0],
+                                              random.random(), random.random())
+                        c = bb.to_centroid(img.shape)
+                        anotations.append(c.to_anotation(label))
+                    except AssertionError as e:
+                        print(f'couldnt process {i}, {j}: {e}')
+
+                try:
+                    cv2.imwrite(f'{dest_images_path}/{basename}.png', img)
+                    label_path = f"{dest_labels_path}/{basename}.txt"
+                    with open(label_path, 'a') as f:
+                        f.write('\n'.join(anotations))
+                except Exception:
+                    print(f'couldnt write image {basename}')
+
+            if i < len(batches)-1:
+                print("Requesting next augmented batch...")
+            bar.next()
+        bar.finish()
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='mix backgrounds and logos into augmented data for YOLO')
+    parser.add_argument('--logos', metavar='logos', type=str,
+                        default=defaults.LOGOS_DATA_PATH,
+                        help='dir containing logos')
+    parser.add_argument('--backgrounds', metavar='backgrounds', type=str,
+                        default=defaults.IMAGES_PATH,
+                        help='dir containing background plates')
+    parser.add_argument('--dst', dest='dest', type=str,
+                        default=defaults.AUGMENTED_DATA_PATH,
+                        help='dest dir')
+
+    args = parser.parse_args()
+    process(args)
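
Two details of this diff are worth a sketch. First, the wrap-around batching:
each batch is topped up from the head of the shuffled logo list so that every
UnnormalizedBatch holds exactly BATCH_SIZE images, and the list is walked
roughly twice. A self-contained sketch of just that arithmetic (make_batches
is a hypothetical name, not in the repo):

    import math

    BATCH_SIZE = 16

    def make_batches(items, batch_size=BATCH_SIZE):
        # when items[s:e] falls short near the end of the list, prepend
        # items[0:le] from the head so every batch has batch_size elements
        n = len(items)
        batches = []
        for i in range(math.floor(n * 2 / batch_size)):
            s = (i * batch_size) % n
            e = min(s + batch_size, n)
            le = max(0, batch_size - (e - s))
            batches.append(items[0:le] + items[s:e])
        return batches

    # with 20 items: batch 0 is items[0:16]; batch 1 wraps,
    # taking items[0:12] + items[16:20]
    print([len(b) for b in make_batches(list(range(20)))])  # -> [16, 16]

Second, the alpha-as-heatmap trick from the XXX comments: the logo's alpha
channel is replicated into a 3-channel float32 array and passed as a heatmap,
so imgaug applies the same geometric transforms to the mask as to the image.
A minimal sketch, using imgaug's synchronous augment_batches in place of the
repo's pipelines.HUGE and multiprocessing pool:

    import numpy as np
    import imgaug.augmenters as iaa
    from imgaug.augmentables.batches import UnnormalizedBatch

    rgb = np.zeros((32, 32, 3), dtype=np.uint8)    # stand-in logo
    alpha = np.ones((32, 32), dtype=np.float32)    # stand-in alpha mask in [0, 1]
    heat = np.dstack((alpha, alpha, alpha)).astype('float32')

    aug = iaa.Affine(rotate=15)
    batch = UnnormalizedBatch(images=[rgb], heatmaps=[heat])
    batch_aug = next(aug.augment_batches([batch]))

    # the heatmap was rotated along with the image; one channel is enough
    # to use as the blending mask afterwards
    mask = batch_aug.heatmaps_aug[0][:, :, 0]
    assert mask.shape == (32, 32)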