Iker committed on
Commit
390a692
·
1 Parent(s): 2a897d7

Add --repetition_penalty flag

Browse files
Files changed (1) hide show
  1. translate.py +12 -0
translate.py CHANGED
@@ -76,6 +76,7 @@ def main(
76
  top_p: float = 1.0,
77
  keep_special_tokens: bool = False,
78
  keep_tokenization_spaces: bool = False,
 
79
  ):
80
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
81
 
@@ -132,6 +133,9 @@ def main(
132
  "top_p": top_p,
133
  }
134
 
 
 
 
135
  total_lines: int = count_lines(sentences_path)
136
 
137
  if accelerator.is_main_process:
@@ -351,6 +355,13 @@ if __name__ == "__main__":
351
  help="Do not clean spaces in the decoded text.",
352
  )
353
 
 
 
 
 
 
 
 
354
  args = parser.parse_args()
355
 
356
  main(
@@ -371,4 +382,5 @@ if __name__ == "__main__":
371
  top_p=args.top_p,
372
  keep_special_tokens=args.keep_special_tokens,
373
  keep_tokenization_spaces=args.keep_tokenization_spaces,
 
374
  )
 
76
  top_p: float = 1.0,
77
  keep_special_tokens: bool = False,
78
  keep_tokenization_spaces: bool = False,
79
+ repetition_penalty: float = None,
80
  ):
81
  os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
82
 
 
133
  "top_p": top_p,
134
  }
135
 
136
+ if repetition_penalty is not None:
137
+ gen_kwargs["repetition_penalty"] = repetition_penalty
138
+
139
  total_lines: int = count_lines(sentences_path)
140
 
141
  if accelerator.is_main_process:
 
355
  help="Do not clean spaces in the decoded text.",
356
  )
357
 
358
+ parser.add_argument(
359
+ "--repetition_penalty",
360
+ type=float,
361
+ default=None,
362
+ help="Repetition penalty.",
363
+ )
364
+
365
  args = parser.parse_args()
366
 
367
  main(
 
382
  top_p=args.top_p,
383
  keep_special_tokens=args.keep_special_tokens,
384
  keep_tokenization_spaces=args.keep_tokenization_spaces,
385
+ repetition_penalty=args.repetition_penalty,
386
  )