Niv Sardi commited on
Commit
a7ac778
1 Parent(s): b1d65c2

python/augment: support --parallel

Browse files
Files changed (1) hide show
  1. python/augment.py +5 -1
python/augment.py CHANGED
@@ -24,6 +24,7 @@ import imtool
24
  import pipelines
25
 
26
  BATCH_SIZE = 16
 
27
 
28
  def process(args):
29
  dest_images_path = os.path.join(args.dest, 'images')
@@ -129,7 +130,7 @@ def process(args):
129
  batches_generator = create_generator(batches)
130
 
131
  batch = 0
132
- with pipeline.pool(processes=-1, seed=1) as pool:
133
  batches_aug = pool.imap_batches(batches_generator, output_buffer_size=5)
134
 
135
  print(f"Requesting next augmented batch...{batch}/{len(batches)}")
@@ -196,6 +197,9 @@ if __name__ == '__main__':
196
  parser.add_argument('--dst', dest='dest', type=str,
197
  default=defaults.AUGMENTED_DATA_PATH,
198
  help='dest dir')
 
 
 
199
 
200
  args = parser.parse_args()
201
  process(args)
 
24
  import pipelines
25
 
26
  BATCH_SIZE = 16
27
+ PARALLEL = 20
28
 
29
  def process(args):
30
  dest_images_path = os.path.join(args.dest, 'images')
 
130
  batches_generator = create_generator(batches)
131
 
132
  batch = 0
133
+ with pipeline.pool(processes=args.parallel, seed=1) as pool:
134
  batches_aug = pool.imap_batches(batches_generator, output_buffer_size=5)
135
 
136
  print(f"Requesting next augmented batch...{batch}/{len(batches)}")
 
197
  parser.add_argument('--dst', dest='dest', type=str,
198
  default=defaults.AUGMENTED_DATA_PATH,
199
  help='dest dir')
200
+ parser.add_argument('--parallel', metavar='parallel', type=int,
201
+ default=PARALLEL,
202
+ help='number of concurrent jobs')
203
 
204
  args = parser.parse_args()
205
  process(args)