Iker committed on
Commit
2a897d7
·
1 Parent(s): b32de48

Add --keep_tokenization_spaces argument to control tokenization-space cleanup during decoding

Browse files
Files changed (1) hide show
  1. translate.py +14 -4
translate.py CHANGED
@@ -31,7 +31,6 @@ def get_dataloader(
31
  batch_size: int,
32
  max_length: int,
33
  ) -> DataLoader:
34
-
35
  dataset = DatasetReader(filename, tokenizer, max_length)
36
  if accelerator.distributed_type == DistributedType.TPU:
37
  data_collator = DataCollatorForSeq2Seq(
@@ -76,8 +75,8 @@ def main(
76
  top_k: int = 50,
77
  top_p: float = 1.0,
78
  keep_special_tokens: bool = False,
 
79
  ):
80
-
81
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
82
 
83
  accelerator = Accelerator(
@@ -149,6 +148,8 @@ def main(
149
  f"Max length: {max_length}\n"
150
  f"Precision: {model.dtype}\n"
151
  f"Model: {model_name}\n"
 
 
152
  )
153
  print("** Generation parameters **")
154
  print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
@@ -197,7 +198,9 @@ def main(
197
  )
198
 
199
  tgt_text = tokenizer.batch_decode(
200
- generated_tokens, skip_special_tokens=not keep_special_tokens
 
 
201
  )
202
  if accelerator.is_main_process:
203
  if (
@@ -342,6 +345,12 @@ if __name__ == "__main__":
342
  help="Keep special tokens in the decoded text.",
343
  )
344
 
 
 
 
 
 
 
345
  args = parser.parse_args()
346
 
347
  main(
@@ -360,5 +369,6 @@ if __name__ == "__main__":
360
  temperature=args.temperature,
361
  top_k=args.top_k,
362
  top_p=args.top_p,
363
- keep_special_tokens=args.keep_special_tokens
 
364
  )
 
31
  batch_size: int,
32
  max_length: int,
33
  ) -> DataLoader:
 
34
  dataset = DatasetReader(filename, tokenizer, max_length)
35
  if accelerator.distributed_type == DistributedType.TPU:
36
  data_collator = DataCollatorForSeq2Seq(
 
75
  top_k: int = 50,
76
  top_p: float = 1.0,
77
  keep_special_tokens: bool = False,
78
+ keep_tokenization_spaces: bool = False,
79
  ):
 
80
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
81
 
82
  accelerator = Accelerator(
 
148
  f"Max length: {max_length}\n"
149
  f"Precision: {model.dtype}\n"
150
  f"Model: {model_name}\n"
151
+ f"Keep special tokens: {keep_special_tokens}\n"
152
+ f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
153
  )
154
  print("** Generation parameters **")
155
  print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
 
198
  )
199
 
200
  tgt_text = tokenizer.batch_decode(
201
+ generated_tokens,
202
+ skip_special_tokens=not keep_special_tokens,
203
+ clean_up_tokenization_spaces=not keep_tokenization_spaces,
204
  )
205
  if accelerator.is_main_process:
206
  if (
 
345
  help="Keep special tokens in the decoded text.",
346
  )
347
 
348
+ parser.add_argument(
349
+ "--keep_tokenization_spaces",
350
+ action="store_true",
351
+ help="Do not clean spaces in the decoded text.",
352
+ )
353
+
354
  args = parser.parse_args()
355
 
356
  main(
 
369
  temperature=args.temperature,
370
  top_k=args.top_k,
371
  top_p=args.top_p,
372
+ keep_special_tokens=args.keep_special_tokens,
373
+ keep_tokenization_spaces=args.keep_tokenization_spaces,
374
  )