"""Back-translation data augmentation for CSV datasets.

Selected columns of a CSV file are translated into a pivot ("bottleneck")
language and back with Helsinki-NLP MarianMT models, and the resulting
paraphrases are written to a new CSV file.
"""
import argparse
import json
import os
import time

import pandas as pd
import torch
from dotenv import load_dotenv
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

load_dotenv()

# Kept after load_dotenv() to preserve the original import order, so any
# variables loaded from a .env file are already in os.environ when nltk is
# first imported.
from nltk.tokenize import sent_tokenize  # noqa: E402

wd = os.path.dirname(os.path.realpath(__file__))


class BackTranslatorAugmenter:
    """
    A class that performs BackTranslation in order to do data augmentation.
    For best results we recommend using bottleneck languages (`out_lang`)
    such as russian (ru) and spanish (es).

    Two MarianMT models are loaded: ``Helsinki-NLP/opus-mt-{in_lang}-{out_lang}``
    (forward) and ``Helsinki-NLP/opus-mt-{out_lang}-{in_lang}`` (backward).
    They are placed on CUDA when available, otherwise on CPU. The model cache
    directory is taken from the ``TRANSFORMERS_CACHE`` environment variable
    (``None`` when unset).

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the language to translate with, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        cache_dir = os.getenv("TRANSFORMERS_CACHE")
        forward_name = f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}"
        backward_name = f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}"

        # in_* translate in_lang -> out_lang; out_* translate back.
        self.in_tokenizer = AutoTokenizer.from_pretrained(
            forward_name, cache_dir=cache_dir
        )
        self.in_model = AutoModelForSeq2SeqLM.from_pretrained(
            forward_name, cache_dir=cache_dir
        ).to(self.device)
        self.out_tokenizer = AutoTokenizer.from_pretrained(
            backward_name, cache_dir=cache_dir
        )
        self.out_model = AutoModelForSeq2SeqLM.from_pretrained(
            backward_name, cache_dir=cache_dir
        ).to(self.device)

    def _translate(self, text, tokenizer, model):
        """Run one translation pass with the given tokenizer/model pair.

        :param text: a string or a list of strings; inputs longer than the
            tokenizer's maximum length are truncated.
        :return: the decoded translations, one string per input sequence.
        :rtype: list[str]
        """
        encoded = tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        generated_ids = model.generate(**encoded)
        return [
            tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in generated_ids
        ]

    def back_translate(self, text, verbose=False):
        """Translate ``text`` to the pivot language and back.

        :param text: a string or a list of strings to paraphrase.
        :param verbose: when True, print the intermediate (pivot-language)
            translations, the final outputs and the elapsed time.
        :return: the back-translated strings, one per input sequence.
        :rtype: list[str]
        """
        tic = time.perf_counter() if verbose else None
        in_preds = self._translate(text, self.in_tokenizer, self.in_model)
        if verbose:
            print("in_pred : ", in_preds)
        out_preds = self._translate(in_preds, self.out_tokenizer, self.out_model)
        if verbose:
            tac = time.perf_counter()
            print("out_pred : ", out_preds)
            print("Elapsed time : ", tac - tic)
        return out_preds

    def back_translate_long(self, text, verbose=False):
        """Back-translate a long text sentence by sentence.

        The text is split with ``nltk.sent_tokenize``, which needs the NLTK
        sentence-tokenizer data to be installed. Each sentence is
        back-translated in one batch, and the results are joined with spaces.

        :param text: a single string.
        :param verbose: forwarded to :meth:`back_translate`.
        :return: a one-element list holding the joined back-translation
            (``[""]`` when ``text`` contains no sentences).
        :rtype: list[str]
        """
        sentences = sent_tokenize(text)
        if not sentences:
            # Avoid calling the tokenizer with an empty batch.
            return [""]
        return [" ".join(self.back_translate(sentences, verbose=verbose))]


def do_backtranslation(**args):
    """Back-translate selected columns of a CSV file and save the result.

    Expected keys in ``args`` (as produced by the CLI parser below):
    ``input_data_path``, ``output_data_path``, ``in_lang``, ``out_lang``,
    ``col_map`` (dict mapping input column name -> output column name) and
    ``batch_size``.

    The output CSV contains one column per ``col_map`` value. It is written
    with pandas' default index column. Missing parent directories of the
    output path are created.

    :raises ValueError: if ``batch_size`` is not a positive integer.
    :raises KeyError: if a ``col_map`` key is not a column of the input CSV.
    """
    batch_size = args["batch_size"]
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    col_map = args["col_map"]

    df = pd.read_csv(args["input_data_path"])
    # Fail fast, before the (slow) model downloads/loads.
    missing = [col for col in col_map if col not in df.columns]
    if missing:
        raise KeyError(
            f"columns {missing} not found in {args['input_data_path']}"
        )

    data_augmenter = BackTranslatorAugmenter(
        in_lang=args["in_lang"], out_lang=args["out_lang"]
    )

    dict_res = {new_col: [] for new_col in col_map.values()}
    for start in tqdm(range(0, len(df), batch_size)):
        batch = df.iloc[start : start + batch_size]
        for old_col, new_col in col_map.items():
            dict_res[new_col] += data_augmenter.back_translate(list(batch[old_col]))

    augmented_df = pd.DataFrame(dict_res)
    output_dir = os.path.dirname(args["output_data_path"])
    # dirname() is "" for a bare file name; os.makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    augmented_df.to_csv(args["output_data_path"])


def _build_parser():
    """Build the command-line parser for :func:`do_backtranslation`."""
    parser = argparse.ArgumentParser(
        description="Back Translate a dataset for better training"
    )
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help="the text input language, defaults to \"en\"; one can choose "
        "between {'es','ru','en','fr','de','pt','zh'} but please have a look "
        "at https://huggingface.co/Helsinki-NLP to make sure the language "
        "pair you ask for is available",
    )
    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help="the bottleneck language used for the round trip, defaults to "
        "\"ru\"; one can choose between {'es','ru','en','fr','de','pt','zh'} "
        "but please have a look at https://huggingface.co/Helsinki-NLP to "
        "make sure the language pair you ask for is available",
    )
    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
        help="dataset location; it should be a CSV file containing the input "
        "columns given as keys of --col_map (by default \"abstract\" and "
        "\"tldr\")",
    )
    parser.add_argument(
        "-output_data_path",
        type=str,
        default=os.path.join(
            wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
        ),
        help="augmented dataset output location",
    )
    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="columns names to apply data augmentation on; "
        "you have to give a JSON key/value dict such as "
        '\'{"input_column_name1": "output_column_name1"}\'. By default '
        'it is set to {"abstract": "text", "tldr": "summary"}. '
        "If you don't want to change the column names, "
        "please provide a dict such that keys = values",
    )
    parser.add_argument("-batch_size", type=int, default=25, help="batch_size")
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()
    do_backtranslation(**vars(args))