Add --keep_tokenization_spaces argument to control space cleanup during decoding
translate.py  CHANGED  +14 -4
@@ -31,7 +31,6 @@ def get_dataloader(
     batch_size: int,
     max_length: int,
 ) -> DataLoader:
-
     dataset = DatasetReader(filename, tokenizer, max_length)
     if accelerator.distributed_type == DistributedType.TPU:
         data_collator = DataCollatorForSeq2Seq(
@@ -76,8 +75,8 @@ def main(
     top_k: int = 50,
     top_p: float = 1.0,
     keep_special_tokens: bool = False,
+    keep_tokenization_spaces: bool = False,
 ):
-
     os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
 
     accelerator = Accelerator(
@@ -149,6 +148,8 @@ def main(
         f"Max length: {max_length}\n"
         f"Precision: {model.dtype}\n"
         f"Model: {model_name}\n"
+        f"Keep special tokens: {keep_special_tokens}\n"
+        f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
     )
     print("** Generation parameters **")
     print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
@@ -197,7 +198,9 @@ def main(
             )
 
             tgt_text = tokenizer.batch_decode(
-                generated_tokens,
+                generated_tokens,
+                skip_special_tokens=not keep_special_tokens,
+                clean_up_tokenization_spaces=not keep_tokenization_spaces,
             )
             if accelerator.is_main_process:
                 if (
@@ -342,6 +345,12 @@ if __name__ == "__main__":
         help="Keep special tokens in the decoded text.",
     )
 
+    parser.add_argument(
+        "--keep_tokenization_spaces",
+        action="store_true",
+        help="Do not clean spaces in the decoded text.",
+    )
+
     args = parser.parse_args()
 
     main(
@@ -360,5 +369,6 @@ if __name__ == "__main__":
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
-        keep_special_tokens=args.keep_special_tokens
+        keep_special_tokens=args.keep_special_tokens,
+        keep_tokenization_spaces=args.keep_tokenization_spaces,
     )
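
Not part of this commit, just for context: a minimal sketch of what the two decode flags map to in tokenizer.batch_decode. The checkpoint name below is only an example, and the exact space-cleanup rules depend on the tokenizer in use.

# Minimal sketch (not from translate.py) of the decode behaviour the flags control.
# "facebook/nllb-200-distilled-600M" is only an example checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
token_ids = tokenizer("Hello, world!").input_ids

# Script default (neither flag passed): special tokens are dropped and
# tokenization spaces are cleaned up.
print(tokenizer.batch_decode([token_ids],
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=True))

# With --keep_special_tokens and --keep_tokenization_spaces the script passes
# both options as False, keeping the raw detokenized text.
print(tokenizer.batch_decode([token_ids],
                             skip_special_tokens=False,
                             clean_up_tokenization_spaces=False))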