Spaces:
Runtime error
Runtime error
import argparse | |
import time | |
from tqdm import tqdm | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import os | |
import json | |
import torch | |
from dotenv import load_dotenv | |
load_dotenv() | |
from nltk.tokenize import sent_tokenize | |
# Directory that contains this script (symlinks resolved); the command-line
# interface builds its default dataset paths relative to it.
wd = os.path.split(os.path.realpath(__file__))[0]
class BackTranslatorAugmenter:
    """
    A class that performs BackTranslation in order to do data augmentation.
    For best results we recommend using bottleneck languages (`out_lang`)
    such as russian (ru) and spanish (es).

    Text is translated ``in_lang -> out_lang`` with the
    ``Helsinki-NLP/opus-mt-{in_lang}-{out_lang}`` checkpoint and then back
    ``out_lang -> in_lang`` with ``Helsinki-NLP/opus-mt-{out_lang}-{in_lang}``.
    Both tokenizer/model pairs are loaded once, in the constructor, and the
    models are moved to CUDA when available, CPU otherwise.

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the language to translate with, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Forward direction: in_lang -> out_lang.
        self.in_tokenizer, self.in_model = self._load_pair(in_lang, out_lang)
        # Backward direction: out_lang -> in_lang.
        self.out_tokenizer, self.out_model = self._load_pair(out_lang, in_lang)

    def _load_pair(self, src_lang, tgt_lang):
        """Load the Helsinki-NLP tokenizer and model for ``src_lang -> tgt_lang``.

        :return: ``(tokenizer, model)`` with the model moved to ``self.device``
        """
        checkpoint = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
        # TRANSFORMERS_CACHE (possibly set through the .env file) selects the
        # cache directory; when unset, os.getenv returns None.
        cache_dir = os.getenv("TRANSFORMERS_CACHE")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=cache_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            checkpoint, cache_dir=cache_dir
        ).to(self.device)
        return tokenizer, model

    def _translate(self, tokenizer, model, texts):
        """Run one translation direction on ``texts``.

        :param texts: a string or a list of strings
        :return: the decoded generated sequences as a list of strings
            (one per input with the default generation settings)
        """
        encoded = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        generated_ids = model.generate(**encoded)
        return tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    def back_translate(self, text, verbose=False):
        """Back-translate ``text`` through the bottleneck language.

        :param text: a single string or a list of strings to augment
        :param verbose: print the intermediate translation, the final
            output and the elapsed time
        :return: list of back-translated strings; ``[]`` when ``text`` is an
            empty list or tuple
        """
        if isinstance(text, (list, tuple)) and not text:
            # An empty batch leaves the tokenizer/model nothing to encode or
            # generate from, so return early instead of calling them.
            return []
        # perf_counter is monotonic, which makes it the right clock for
        # measuring an interval.
        start = time.perf_counter()
        in_preds = self._translate(self.in_tokenizer, self.in_model, text)
        if verbose:
            print("in_pred : ", in_preds)
        out_preds = self._translate(self.out_tokenizer, self.out_model, in_preds)
        if verbose:
            print("out_pred : ", out_preds)
            print("Elapsed time : ", time.perf_counter() - start)
        return out_preds

    def back_translate_long(self, text, verbose=False):
        """Back-translate a long text sentence by sentence.

        The text is split with NLTK's ``sent_tokenize``, all sentences are
        back-translated as one batch, and the results are re-joined with
        single spaces.

        :param text: the text to augment
        :param verbose: forwarded to :meth:`back_translate`
        :return: a one-element list holding the re-joined text
        """
        sentences = sent_tokenize(text)
        return [" ".join(self.back_translate(sentences, verbose=verbose))]
def do_backtranslation(**args):
    """Back-translate selected columns of a CSV file and save the result.

    Expected keys in ``args`` (the same names as the script's command-line
    options):

    - ``input_data_path``: CSV file to read.
    - ``output_data_path``: CSV file to write; its parent directory is
      created when the path has one.
    - ``in_lang`` / ``out_lang``: language pair passed to
      ``BackTranslatorAugmenter``.
    - ``col_map``: dict mapping each input column to augment to the name of
      the output column holding its back-translation.
    - ``batch_size``: number of rows back-translated per call; must be >= 1.
    - ``max_rows`` (optional): only process the first ``max_rows`` rows;
      every row is processed when the key is absent or ``None``.

    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    batch_size = args["batch_size"]
    if batch_size < 1:
        # range() with step 0 raises and a negative step silently yields no
        # batches (an empty output file), so reject both up front.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    df = pd.read_csv(args["input_data_path"])
    max_rows = args.get("max_rows")
    if max_rows is not None:
        df = df.iloc[:max_rows]

    data_augmenter = BackTranslatorAugmenter(
        in_lang=args["in_lang"], out_lang=args["out_lang"]
    )
    col_map = args["col_map"]
    dict_res = {new_col: [] for new_col in col_map.values()}
    for start in tqdm(range(0, len(df), batch_size)):
        batch = df.iloc[start : start + batch_size]
        for old_col, new_col in col_map.items():
            dict_res[new_col] += data_augmenter.back_translate(list(batch[old_col]))
    augmented_df = pd.DataFrame(dict_res)

    output_path = args["output_data_path"]
    output_dir = os.path.dirname(output_path)
    # A bare file name has dirname "" and os.makedirs("") raises, so only
    # create a directory when the path actually contains one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    augmented_df.to_csv(output_path)
def build_parser(base_dir):
    """Build the command-line parser of the back-translation script.

    :param base_dir: directory the default input/output dataset paths are
        resolved against
    :return: the configured :class:`argparse.ArgumentParser`; its parsed
        namespace holds ``in_lang``, ``out_lang``, ``input_data_path``,
        ``output_data_path``, ``col_map`` and ``batch_size``
    """
    parser = argparse.ArgumentParser(
        description="Back Translate a dataset for better training"
    )
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help='the text input language, defaults to "en", one can choose '
        "between {'es','ru','en','fr','de','pt','zh'} but please have a look "
        "at https://huggingface.co/Helsinki-NLP to make sure the language "
        "pair you ask for is available",
    )
    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help='the bottleneck language to translate through, defaults to "ru", '
        "one can choose between {'es','ru','en','fr','de','pt','zh'} but "
        "please have a look at https://huggingface.co/Helsinki-NLP to make "
        "sure the language pair you ask for is available",
    )
    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(base_dir, "dataset", "train_neurips_dataset.csv"),
        help="dataset location; it must be a CSV file containing every input "
        'column listed in --col_map (by default "abstract" and "tldr")',
    )
    parser.add_argument(
        "-output_data_path",
        type=str,
        default=os.path.join(
            base_dir, "dataset", "augmented_datas", "augmented_dataset_output.csv"
        ),
        help="augmented dataset output location",
    )
    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="JSON object mapping the columns to apply data augmentation on "
        "to the column names used in the output, e.g. "
        '\'{"input_column_name1": "output_column_name1"}\' '
        "(JSON requires double quotes). By default it is set to "
        '\'{"abstract": "text", "tldr": "summary"}\'. '
        "If you don't want to change the column names, map each key to itself.",
    )
    parser.add_argument(
        "-batch_size",
        type=int,
        default=25,
        help="number of rows back-translated per model call",
    )
    return parser


if __name__ == "__main__":
    # Default dataset paths are resolved relative to the script directory.
    args = build_parser(wd).parse_args()
    do_backtranslation(**vars(args))