Niv Sardi commited on
Commit
9477f68
1 Parent(s): 05c71b4

augment: correctly pass alpha and labels

Browse files

Signed-off-by: Niv Sardi <xaiki@evilgiggle.com>

Files changed (2) hide show
  1. python/augment.py +24 -4
  2. python/imtool.py +1 -1
python/augment.py CHANGED
@@ -88,8 +88,27 @@ for d in os.scandir(defaults.LOGOS_DATA_PATH):
88
  print(f'error loading: {d.path}: {e}')
89
 
90
  print(stats)
91
- batches = [UnnormalizedBatch(images=logo_images[i:i+BATCH_SIZE],heatmaps=logo_alphas[i:i+BATCH_SIZE])
92
- for i in range(math.floor(len(logo_images)/BATCH_SIZE))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  # We use a single, very fast augmenter here to show that batches
95
  # are only loaded once there is space again in the buffer.
@@ -125,8 +144,9 @@ with pipeline.pool(processes=-1, seed=1) as pool:
125
  # we could make mix_alpha into mix_mask and pass all 3 chanels
126
  alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
127
  try:
128
- img, bb, (w, h) = imtool.mix_alpha(img, logo, alpha[0], random.random(), random.random())
129
- c = bb.to_centroid((h, w, 1))
 
130
  anotations.append(c.to_anotation(label))
131
  except AssertionError as e:
132
  print(f'couldnt process {i}, {j}: {e}')
 
88
  print(f'error loading: {d.path}: {e}')
89
 
90
  print(stats)
91
+ #print(len(logo_alphas), len(logo_images), len(logo_labels))
92
+ assert(len(logo_alphas) == len(logo_images))
93
+
94
+ # so that we don't get a lot of the same logos on the same page.
95
+ zipped = list(zip(logo_images, logo_alphas))
96
+ random.shuffle(zipped)
97
+ logo_images, logo_alphas = zip(*zipped)
98
+
99
+ n = len(logo_images)
100
+ batches = []
101
+ for i in range(math.floor(n*2/BATCH_SIZE)):
102
+ s = (i*BATCH_SIZE)%n
103
+ e = min(s + BATCH_SIZE, n)
104
+ le = max(0, BATCH_SIZE - (e - s))
105
+
106
+ a = logo_images[0:le] + logo_images[s:e]
107
+ h = logo_alphas[0:le] + logo_alphas[s:e]
108
+
109
+ assert(len(a) == BATCH_SIZE)
110
+
111
+ batches.append(UnnormalizedBatch(images=a,heatmaps=h))
112
 
113
  # We use a single, very fast augmenter here to show that batches
114
  # are only loaded once there is space again in the buffer.
 
144
  # we could make mix_alpha into mix_mask and pass all 3 chanels
145
  alpha = cv2.split(batch_aug.heatmaps_aug[logo_idx])
146
  try:
147
+ bb = imtool.mix_alpha(img, logo, alpha[0],
148
+ random.random(), random.random())
149
+ c = bb.to_centroid(img.shape)
150
  anotations.append(c.to_anotation(label))
151
  except AssertionError as e:
152
  print(f'couldnt process {i}, {j}: {e}')
python/imtool.py CHANGED
@@ -183,7 +183,7 @@ def _mix_alpha(a, b, ba, fx, fy):
183
 
184
  a[y:y+bh,x:x+bw] = mat * (1 - mask) + cols * mask
185
 
186
- return a, BoundingBox(x, y, bw, bh), (aw, ah)
187
 
188
  def crop(id, fn, logos: List[Centroid], out = './data/squares'):
189
  basename = os.path.basename(fn).replace('.png', '')
 
183
 
184
  a[y:y+bh,x:x+bw] = mat * (1 - mask) + cols * mask
185
 
186
+ return BoundingBox(x, y, bw, bh)
187
 
188
  def crop(id, fn, logos: List[Centroid], out = './data/squares'):
189
  basename = os.path.basename(fn).replace('.png', '')