# paraphraser_ai / backend / data_augmenter.py
import argparse
import json
import os
import time

import pandas as pd
import torch
from dotenv import load_dotenv
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Read environment variables (e.g. TRANSFORMERS_CACHE) from a local .env file.
load_dotenv()

wd = os.path.dirname(os.path.realpath(__file__))


class BackTranslatorAugmenter:
    """
    A class that performs back-translation for data augmentation.

    For best results we recommend using bottleneck languages (``out_lang``)
    such as Russian (``ru``) and Spanish (``es``).

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the bottleneck language to translate through, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.in_tokenizer = AutoTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        )
        self.in_model = AutoModelForSeq2SeqLM.from_pretrained(
            f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        ).to(self.device)
        self.out_tokenizer = AutoTokenizer.from_pretrained(
            f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        )
        self.out_model = AutoModelForSeq2SeqLM.from_pretrained(
            f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}",
            cache_dir=os.getenv("TRANSFORMERS_CACHE"),
        ).to(self.device)
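
    # With the default pair (in_lang="en", out_lang="ru"), the __init__ above
    # loads the checkpoints "Helsinki-NLP/opus-mt-en-ru" (forward translation)
    # and "Helsinki-NLP/opus-mt-ru-en" (back translation), cached under
    # TRANSFORMERS_CACHE when that environment variable is set.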

    def back_translate(self, text, verbose=False):
        """Translate ``text`` into the bottleneck language and back.

        ``text`` may be a single string or a list of strings (a batch); a list
        of back-translated strings is returned either way.
        """
        if verbose:
            tic = time.time()
        encoded_text = self.in_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        in_generated_ids = self.in_model.generate(**encoded_text)
        in_preds = [
            self.in_tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in in_generated_ids
        ]
        if verbose:
            print("in_pred : ", in_preds)
        encoded_text = self.out_tokenizer(
            in_preds, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        out_generated_ids = self.out_model.generate(**encoded_text)
        out_preds = [
            self.out_tokenizer.decode(
                gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for gen_id in out_generated_ids
        ]
        if verbose:
            tac = time.time()
            print("out_pred : ", out_preds)
            print("Elapsed time : ", tac - tic)
        return out_preds

    def back_translate_long(self, text, verbose=False):
        """Back-translate a long text by splitting it into sentences first."""
        sentences = sent_tokenize(text)
        return [" ".join(self.back_translate(sentences, verbose=verbose))]


def do_backtranslation(**args):
    """Read the input CSV, back-translate the mapped columns batch by batch,
    and write the augmented dataset to ``output_data_path``."""
    df = pd.read_csv(args["input_data_path"])
    data_augmenter = BackTranslatorAugmenter(
        in_lang=args["in_lang"], out_lang=args["out_lang"]
    )
    dict_res = {col_name: [] for _, col_name in args["col_map"].items()}
    for i in tqdm(range(0, len(df), args["batch_size"])):
        for old_col, new_col in args["col_map"].items():
            dict_res[new_col] += data_augmenter.back_translate(
                list(df[old_col].iloc[i : i + args["batch_size"]])
            )
    augmented_df = pd.DataFrame(dict_res)
    os.makedirs(os.path.dirname(args["output_data_path"]), exist_ok=True)
    augmented_df.to_csv(args["output_data_path"])
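
# Minimal programmatic sketch (not part of the CLI below): do_backtranslation
# takes the same keyword arguments that argparse produces; the paths shown here
# are the script's own defaults.
#
#     do_backtranslation(
#         input_data_path=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
#         output_data_path=os.path.join(
#             wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
#         ),
#         in_lang="en",
#         out_lang="ru",
#         col_map={"abstract": "text", "tldr": "summary"},
#         batch_size=25,
#     )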


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Back-translate a dataset for better training"
    )
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help="""the text input language, defaults to "en";
        one can choose between {'es','ru','en','fr','de','pt','zh'},
        but please have a look at https://huggingface.co/Helsinki-NLP to make
        sure the language pair you ask for is available""",
    )
    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help="the bottleneck language to translate through, defaults to 'ru'; "
        "one can choose between {'es','ru','en','fr','de','pt','zh'}, but "
        "please have a look at https://huggingface.co/Helsinki-NLP to make "
        "sure the language pair you ask for is available",
    )
    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
        help="dataset location; it should be a CSV file containing the columns "
        "given as keys in --col_map (by default 'abstract' and 'tldr')",
    )
    parser.add_argument(
        "-output_data_path",
        type=str,
        default=os.path.join(
            wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
        ),
        help="augmented dataset output location",
    )
    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="column names to apply data augmentation on; provide a key/value "
        "mapping as JSON (double quotes are required because the value is "
        'parsed with json.loads), e.g. {"input_column_name1": '
        '"output_column_name1"}; by default it is set to '
        '{"abstract": "text", "tldr": "summary"}; if you do not want to change '
        "the column names, provide a mapping whose keys equal its values",
    )
parser.add_argument("-batch_size", type=int, default=25, help="batch_size")
args = parser.parse_args()
do_backtranslation(**vars(args))
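
# Example invocation (a sketch; adjust paths and languages to your setup). Note
# that --col_map is parsed with json.loads, so the mapping must be valid JSON
# with double quotes:
#
#     python data_augmenter.py \
#         -in_lang en \
#         -out_lang ru \
#         -input_data_path dataset/train_neurips_dataset.csv \
#         -output_data_path dataset/augmented_datas/augmented_dataset_output.csv \
#         --col_map '{"abstract": "text", "tldr": "summary"}' \
#         -batch_size 25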