# easy-translate / translate.py
# Latest commit 390a692 by Iker: "Add --repetition-penalty flag" (file size: 12 kB)
import os
import math
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedTokenizerBase,
DataCollatorForSeq2Seq,
)
from dataset import DatasetReader, count_lines
from accelerate import Accelerator, DistributedType, find_executable_batch_size
def encode_string(text):
    """Escape control characters so a translation occupies exactly one output line.

    Carriage return, newline and tab are each replaced by a literal backslash
    followed by ``r``, ``n`` or ``t`` respectively. All other characters,
    including existing backslashes, are left unchanged.

    Args:
        text: A decoded translation.

    Returns:
        The escaped string.
    """
    escape_table = str.maketrans({"\r": r"\r", "\n": r"\n", "\t": r"\t"})
    return text.translate(escape_table)
def get_dataloader(
    accelerator: Accelerator,
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
) -> DataLoader:
    """Create a DataLoader that yields padded, tokenized batches from *filename*.

    Args:
        accelerator: Used only to detect whether we run on TPU.
        filename: Text file passed to ``DatasetReader``.
        tokenizer: Tokenizer used both by the dataset and by the collator for padding.
        batch_size: Number of sentences per batch.
        max_length: Passed to ``DatasetReader``; on TPU it is also the fixed padding length.

    Returns:
        A ``DataLoader`` running in the main process (``num_workers=0``).
    """
    sentences = DatasetReader(filename, tokenizer, max_length)

    # Options shared by both padding strategies.
    collator_options = {
        "label_pad_token_id": tokenizer.pad_token_id,
        "return_tensors": "pt",
    }
    if accelerator.distributed_type != DistributedType.TPU:
        # Dynamic padding to the longest sequence in the batch, rounded up to a
        # multiple of 8. max_length is not passed: inputs are already truncated
        # during preprocessing (DatasetReader receives max_length).
        collator_options.update(padding=True, pad_to_multiple_of=8)
    else:
        # Fixed-length padding on TPU (presumably so XLA sees static shapes — TODO confirm).
        collator_options.update(padding="max_length", max_length=max_length)

    collate = DataCollatorForSeq2Seq(tokenizer, **collator_options)

    return DataLoader(
        sentences,
        batch_size=batch_size,
        collate_fn=collate,
        num_workers=0,  # no worker subprocesses
    )
def _lang_code_id(tokenizer: PreTrainedTokenizerBase, lang: str) -> int:
    """Return ``tokenizer.lang_code_to_id[lang]``.

    Raises:
        KeyError: if *lang* is not a language code of the tokenizer; the message
            lists the available codes.
    """
    try:
        return tokenizer.lang_code_to_id[lang]
    except KeyError:
        # Replace the bare KeyError with one that tells the user what is valid.
        raise KeyError(
            f"Language {lang} not found in tokenizer. Available languages: {tokenizer.lang_code_to_id.keys()}"
        ) from None


def main(
    sentences_path: str,
    output_path: str,
    source_lang: str,
    target_lang: str,
    starting_batch_size: int,
    model_name: str = "facebook/m2m100_1.2B",
    cache_dir: str = None,
    precision: str = "32",
    max_length: int = 128,
    num_beams: int = 4,
    num_return_sequences: int = 1,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 1.0,
    keep_special_tokens: bool = False,
    keep_tokenization_spaces: bool = False,
    repetition_penalty: float = None,
):
    """Translate every line of *sentences_path* and write the results to *output_path*.

    Each translation is written on its own line (control characters escaped with
    ``encode_string``); with ``num_return_sequences > 1`` the hypotheses of a
    sentence are written as consecutive lines. The batch size starts at
    *starting_batch_size* and is reduced automatically by accelerate's
    ``find_executable_batch_size`` when generation runs out of memory.

    Args:
        sentences_path: Input text file, one sentence per line.
        output_path: Output text file; its directory is created if missing.
        source_lang: Source language code (must be in ``tokenizer.lang_code_to_id``).
        target_lang: Target language code (must be in ``tokenizer.lang_code_to_id``).
        starting_batch_size: Initial per-process batch size.
        model_name: Model name or path for ``from_pretrained``.
        cache_dir: Optional cache directory for ``from_pretrained``.
        precision: "32", "fp16" or "bf16".
        max_length: Max tokens for source truncation and for generation.
        num_beams, num_return_sequences, do_sample, temperature, top_k, top_p:
            Forwarded to ``model.generate``.
        keep_special_tokens: If True, special tokens are kept when decoding.
        keep_tokenization_spaces: If True, tokenization spaces are not cleaned up.
        repetition_penalty: Forwarded to ``model.generate`` only when not None.

    Raises:
        ValueError: if *precision* is not supported.
        KeyError: if *source_lang* or *target_lang* is unknown to the tokenizer.
    """
    import contextlib

    # Validate cheap arguments before loading the (large) model.
    if precision not in ("32", "fp16", "bf16"):
        raise ValueError("Precision not supported. Supported values: 32, fp16, bf16")

    os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)

    accelerator = Accelerator(
        mixed_precision=precision if precision != "32" else "no",
        split_batches=False,
        dispatch_batches=False,
    )

    print(f"Loading tokenizer {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_name, cache_dir=cache_dir
    )
    print(f"Loading model {model_name}...")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        pretrained_model_name_or_path=model_name, cache_dir=cache_dir
    )
    model.eval()

    print("Preparing data...\n")

    if precision == "32":
        model = model.float()
    elif precision == "fp16":
        model = model.half()
    else:  # "bf16" (validated above)
        model = model.bfloat16()

    _lang_code_id(tokenizer, source_lang)
    tokenizer.src_lang = source_lang
    # Token id forced as the first generated token to select the target language.
    lang_code_to_idx = _lang_code_id(tokenizer, target_lang)

    gen_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }
    if repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = repetition_penalty

    total_lines: int = count_lines(sentences_path)

    if accelerator.is_main_process:
        print(
            f"** Translation **\n"
            f"Input file: {sentences_path}\n"
            f"Output file: {output_path}\n"
            f"Source language: {source_lang}\n"
            f"Target language: {target_lang}\n"
            f"Starting batch size: {starting_batch_size}\n"
            f"Device: {str(accelerator.device).split(':')[0]}\n"
            f"Num. Devices: {accelerator.num_processes}\n"
            f"Distributed_type: {accelerator.distributed_type}\n"
            f"Max length: {max_length}\n"
            f"Precision: {model.dtype}\n"
            f"Model: {model_name}\n"
            f"Keep special tokens: {keep_special_tokens}\n"
            f"Keep tokenization spaces: {keep_tokenization_spaces}\n"
        )
        print("** Generation parameters **")
        print("\n".join(f"{k}: {v}" for k, v in gen_kwargs.items()))
        print("\n")

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        # `model` is rebound by accelerator.prepare() below; nothing else is reassigned.
        nonlocal model

        print(f"Translating with batch size {batch_size}")

        data_loader = get_dataloader(
            accelerator=accelerator,
            filename=sentences_path,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_length=max_length,
        )

        model, data_loader = accelerator.prepare(model, data_loader)

        # Index of the last step each process runs: the input is split into
        # ceil(total_lines / batch_size) batches shared among num_processes.
        # On that step the gathered output may contain extra samples (presumably
        # padding added by accelerate so every process gets a batch — TODO
        # confirm), so the main process truncates to what is still expected.
        last_step = (
            math.ceil(math.ceil(total_lines / batch_size) / accelerator.num_processes)
            - 1
        )
        expected_outputs = total_lines * num_return_sequences
        samples_seen: int = 0

        with tqdm(
            total=total_lines,
            desc="Dataset translation",
            leave=True,
            ascii=True,
            disable=(not accelerator.is_main_process),
        ) as pbar, (
            # Only the main process writes. Opening the file in "w" mode on other
            # processes would truncate it while the main process is writing.
            open(output_path, "w", encoding="utf-8")
            if accelerator.is_main_process
            else contextlib.nullcontext()
        ) as output_file:
            with torch.no_grad():
                for step, batch in enumerate(data_loader):
                    generated_tokens = accelerator.unwrap_model(model).generate(
                        **batch, forced_bos_token_id=lang_code_to_idx, **gen_kwargs
                    )
                    # Sequences differ in length across processes; pad before gathering.
                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                    )
                    generated_tokens = (
                        accelerator.gather(generated_tokens).cpu().numpy()
                    )

                    tgt_text = tokenizer.batch_decode(
                        generated_tokens,
                        skip_special_tokens=not keep_special_tokens,
                        clean_up_tokenization_spaces=not keep_tokenization_spaces,
                    )

                    if accelerator.is_main_process:
                        if step == last_step:
                            tgt_text = tgt_text[: expected_outputs - samples_seen]
                        else:
                            samples_seen += len(tgt_text)

                        print(
                            "\n".join(
                                [encode_string(sentence) for sentence in tgt_text]
                            ),
                            file=output_file,
                        )

                    pbar.update(len(tgt_text) // gen_kwargs["num_return_sequences"])

    inference()
    print("Translation done.\n")
if __name__ == "__main__":
    # Command-line entry point. Note that several CLI defaults (num_beams,
    # temperature, top_k, top_p) differ from the keyword defaults of main().
    parser = argparse.ArgumentParser(description="Run the translation experiments")

    parser.add_argument(
        "--sentences_path",
        type=str,
        required=True,
        help="Path to a txt file containing the sentences to translate. One sentence per line.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to a txt file where the translated sentences will be written.",
    )

    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="Source language id. See: supported_languages.md",
    )

    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="Target language id. See: supported_languages.md",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=128,
        help="Starting batch size, we will automatically reduce it if we find an OOM error. "
        "If you use multiple devices, this is the batch size used by each device.",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/m2m100_1.2B",
        help="Path to the model to use. See: https://huggingface.co/models",
    )

    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Cache directory from which to load the model, or None to not cache",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=128,
        help="Maximum number of tokens in the source sentence and generated sentence. "
        "Increase this value to translate longer sentences, at the cost of increasing memory usage.",
    )

    parser.add_argument(
        "--num_beams",
        type=int,
        default=5,
        help="Number of beams for beam search, m2m100 authors recommend 5, but it might use too much memory",
    )

    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help="Number of possible translations to return for each sentence (num_return_sequences<=num_beams).",
    )

    parser.add_argument(
        "--precision",
        type=str,
        default="32",
        choices=["bf16", "fp16", "32"],
        help="Precision of the model. bf16, fp16 or 32.",
    )

    parser.add_argument(
        "--do_sample",
        action="store_true",
        help="Use sampling instead of beam search.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Temperature for sampling, value used only if do_sample is True.",
    )

    parser.add_argument(
        "--top_k",
        type=int,
        default=100,
        help="If do_sample is True, will sample from the top k most likely tokens.",
    )

    parser.add_argument(
        "--top_p",
        type=float,
        default=0.75,
        help="If do_sample is True, will sample from the smallest set of most likely tokens "
        "whose cumulative probability reaches top_p (nucleus sampling).",
    )

    parser.add_argument(
        "--keep_special_tokens",
        action="store_true",
        help="Keep special tokens in the decoded text.",
    )

    parser.add_argument(
        "--keep_tokenization_spaces",
        action="store_true",
        help="Do not clean spaces in the decoded text.",
    )

    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=None,
        help="Repetition penalty for generation (1.0 means no penalty). "
        "If not set, it is not passed to generate().",
    )

    args = parser.parse_args()

    main(
        sentences_path=args.sentences_path,
        output_path=args.output_path,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        starting_batch_size=args.starting_batch_size,
        model_name=args.model_name,
        cache_dir=args.cache_dir,
        max_length=args.max_length,
        num_beams=args.num_beams,
        num_return_sequences=args.num_return_sequences,
        precision=args.precision,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        keep_special_tokens=args.keep_special_tokens,
        keep_tokenization_spaces=args.keep_tokenization_spaces,
        repetition_penalty=args.repetition_penalty,
    )