Iker committed on
Commit
01e13d9
1 Parent(s): 329e788

Fix multi-GPU batch split bug

Browse files
Files changed (1) hide show
  1. translate.py +7 -5
translate.py CHANGED
@@ -70,7 +70,9 @@ def main(
70
  os.makedirs(os.path.abspath(os.path.dirname(output_path)))
71
 
72
  accelerator = Accelerator(
73
- mixed_precision=precision if precision != "32" else "no", split_batches=True
 
 
74
  )
75
 
76
  print(f"Loading tokenizer {model_name}...")
@@ -182,7 +184,7 @@ def main(
182
  if accelerator.is_main_process:
183
  if step == len(data_loader) - 1:
184
  tgt_text = tgt_text[
185
- : len(data_loader.dataset) * num_return_sequences
186
  - samples_seen
187
  ]
188
  else:
@@ -287,21 +289,21 @@ if __name__ == "__main__":
287
  parser.add_argument(
288
  "--temperature",
289
  type=float,
290
- default=1.0,
291
  help="Temperature for sampling, value used only if do_sample is True.",
292
  )
293
 
294
  parser.add_argument(
295
  "--top_k",
296
  type=int,
297
- default=50,
298
  help="If do_sample is True, will sample from the top k most likely tokens.",
299
  )
300
 
301
  parser.add_argument(
302
  "--top_p",
303
  type=float,
304
- default=1.0,
305
  help="If do_sample is True, will sample from the top k most likely tokens.",
306
  )
307
 
 
70
  os.makedirs(os.path.abspath(os.path.dirname(output_path)))
71
 
72
  accelerator = Accelerator(
73
+ mixed_precision=precision if precision != "32" else "no",
74
+ split_batches=False,
75
+ dispatch_batches=True,
76
  )
77
 
78
  print(f"Loading tokenizer {model_name}...")
 
184
  if accelerator.is_main_process:
185
  if step == len(data_loader) - 1:
186
  tgt_text = tgt_text[
187
+ : (len(data_loader.dataset) * num_return_sequences)
188
  - samples_seen
189
  ]
190
  else:
 
289
  parser.add_argument(
290
  "--temperature",
291
  type=float,
292
+ default=0.8,
293
  help="Temperature for sampling, value used only if do_sample is True.",
294
  )
295
 
296
  parser.add_argument(
297
  "--top_k",
298
  type=int,
299
+ default=100,
300
  help="If do_sample is True, will sample from the top k most likely tokens.",
301
  )
302
 
303
  parser.add_argument(
304
  "--top_p",
305
  type=float,
306
+ default=0.75,
307
  help="If do_sample is True, will sample from the top k most likely tokens.",
308
  )
309