# paraphraser_ai/backend/data_augmenter.py
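"""Back-translation data augmentation.

Translates text into a bottleneck (pivot) language and back, using
Helsinki-NLP MarianMT model pairs, to generate paraphrased training data.
"""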
#%%
import argparse
import json
import os
import time

import pandas as pd
import torch
from dotenv import load_dotenv
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

#%%
load_dotenv()

wd = os.path.dirname(os.path.realpath(__file__))


class BackTranslatorAugmenter:
    """
    Performs back-translation for data augmentation: the input text is
    translated into a bottleneck language and then back into the source
    language, which yields paraphrases. For best results we recommend
    bottleneck languages (`out_lang`) such as Russian (ru) and Spanish (es).

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the bottleneck language to translate through, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
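        # Forward model: translates in_lang -> out_lang.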
self.in_tokenizer = AutoTokenizer.from_pretrained(
f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
)
self.in_model = AutoModelForSeq2SeqLM.from_pretrained(
f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
).to(self.device)
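        # Backward model: translates out_lang -> in_lang, producing the paraphrase.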
self.out_tokenizer = AutoTokenizer.from_pretrained(
f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
)
self.out_model = AutoModelForSeq2SeqLM.from_pretrained(
f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
cache_dir=os.getenv("TRANSFORMERS_CACHE"),
).to(self.device)

    def back_translate(self, text, verbose=False):
        if verbose:
            tic = time.time()
        encoded_text = self.in_tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_overflowing_tokens=True,
        ).to(self.device)
        # Fall back to sentence-level back-translation if any input in the
        # batch was truncated, not just the first one.
        if encoded_text["num_truncated_tokens"].max() > 0:
            print("Text is too long, back-translating sentence by sentence")
            return self.back_translate_long(text, verbose=verbose)
        in_generated_ids = self.in_model.generate(
            inputs=encoded_text["input_ids"],
            attention_mask=encoded_text["attention_mask"],
        )
in_preds = [
self.in_tokenizer.decode(
gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for gen_id in in_generated_ids
]
if verbose:
print("in_pred : ", in_preds)
        encoded_text = self.out_tokenizer(
            in_preds,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_overflowing_tokens=True,
        ).to(self.device)
        out_generated_ids = self.out_model.generate(
            inputs=encoded_text["input_ids"],
            attention_mask=encoded_text["attention_mask"],
        )
out_preds = [
self.out_tokenizer.decode(
gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for gen_id in out_generated_ids
]
if verbose:
tac = time.time()
print("out_pred : ", out_preds)
print("Elapsed time : ", tac - tic)
return out_preds

    def back_translate_long(self, text, verbose=False):
        # Back-translate sentence by sentence so each chunk fits the model's
        # maximum input length; accepts a single string or a batch (list) of strings.
        texts = [text] if isinstance(text, str) else text
        return [" ".join(self.back_translate(sent_tokenize(t), verbose=verbose)) for t in texts]


def do_backtranslation(**args):
    df = pd.read_csv(args["input_data_path"])
data_augmenter = BackTranslatorAugmenter(
in_lang=args["in_lang"], out_lang=args["out_lang"]
)
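    # One result list per output column, filled batch by batch below.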
    dict_res = {col_name: [] for col_name in args["col_map"].values()}
for i in tqdm(range(0, len(df), args["batch_size"])):
for old_col, new_col in args["col_map"].items():
dict_res[new_col] += data_augmenter.back_translate(
list(df[old_col].iloc[i : i + args["batch_size"]])
)
augmented_df = pd.DataFrame(dict_res)
    os.makedirs(os.path.dirname(args["output_data_path"]), exist_ok=True)
    augmented_df.to_csv(args["output_data_path"], index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Back Translate a dataset for better training"
)
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help="""The text input language, defaults to "en".
        One can choose between {'es','ru','en','fr','de','pt','zh'},
        but please have a look at https://huggingface.co/Helsinki-NLP to make
        sure the language pair you ask for is available.""",
    )
    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help="The bottleneck language to translate through. One can choose "
        "between {'es','ru','en','fr','de','pt','zh'}, but please have a "
        "look at https://huggingface.co/Helsinki-NLP to make sure the "
        "language pair you ask for is available.",
    )
    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
        help="Dataset location. Note it should be a CSV file whose columns "
        "match the keys of --col_map (by default 'abstract' and 'tldr').",
    )
parser.add_argument(
"-output_data_path",
type=str,
default=os.path.join(
wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
),
help="augmented dataset output location",
)
    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="Column names to apply data augmentation on. You have to give "
        "a key/value dict such as {'input_column_name1': 'output_column_name1'}; "
        "by default it is set to {'abstract': 'text', 'tldr': 'summary'}. "
        "If you don't want to change the column names, please provide a dict "
        "whose keys equal its values.",
    )
parser.add_argument("-batch_size", type=int, default=25, help="batch_size")
args = parser.parse_args()
do_backtranslation(**vars(args))
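
# Example invocation (paths and column mapping are illustrative; adjust to your data):
#   python data_augmenter.py -in_lang en -out_lang es \
#       -input_data_path dataset/train_neurips_dataset.csv \
#       -output_data_path dataset/augmented_datas/output.csv \
#       --col_map '{"abstract": "text", "tldr": "summary"}' \
#       -batch_size 25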