Spaces:
Sleeping
Sleeping
import random | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import numpy as np | |
import nltk | |
from data_builder import load_data, save_data | |
from model import from_pretrained | |
class T5Paraphraser: | |
def __init__(self, args): | |
self.device = args.device | |
self.tokenizer = from_pretrained(AutoTokenizer, args.t5_model_name, {}, args.cache_dir) | |
self.model = from_pretrained(AutoModelForSeq2SeqLM, args.t5_model_name, {}, args.cache_dir) | |
self.model = self.model.to(args.device) | |
self.model.eval() | |
def paraphrase(self, sents): | |
parabatch = ["paraphrase: " + sent + " </s>" for sent in sents] | |
encoding = self.tokenizer(parabatch, padding=True, return_tensors="pt") | |
input_ids, attention_masks = encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) | |
outputs = self.model.generate( | |
input_ids=input_ids, attention_mask=attention_masks, | |
max_length=256, | |
do_sample=True, | |
top_k=200, | |
top_p=0.95, | |
early_stopping=True, | |
num_return_sequences=1 | |
) | |
assert len(sents) == len(outputs) | |
results = [] | |
for output, sent in zip(outputs, sents): | |
line = self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) | |
line = line.strip() | |
line = line if len(line) > 0 else sent | |
results.append(line) | |
return results | |
class RandomParaphraser: | |
def __init__(self, args): | |
self.device = args.device | |
def paraphrase(self, sents): | |
results = [] | |
for sent in sents: | |
words = sent.split() | |
if len(words) > 20: | |
idx = random.randint(0, len(words) - 2) | |
words[idx], words[idx+1] = words[idx+1], words[idx] | |
results.append(' '.join(words)) | |
return results | |
def generate_data(args): | |
data = load_data(args.dataset_file) | |
originals = data['original'] | |
samples = data['sampled'] | |
print(f"Total number of samples: {len(samples)}") | |
print(f"Average number of words: {np.mean([len(x.split()) for x in samples])}") | |
if args.do_random_para: | |
print(f'Using random paraphraser.') | |
paraphraser = RandomParaphraser(args) | |
else: | |
print(f'Loading model {args.t5_model_name}...') | |
paraphraser = T5Paraphraser(args) | |
new_samples = [] | |
for sample in tqdm(samples): | |
lines = sample.split('\n') | |
new_lines = [] | |
for line in lines: | |
line = line.strip() | |
if len(line) == 0: | |
new_lines.append(line) | |
else: | |
sents = nltk.sent_tokenize(line) | |
new_sents = paraphraser.paraphrase(sents) | |
new_lines.append(' '.join(new_sents)) | |
new_samples.append('\n'.join(new_lines)) | |
new_data = {'original': originals, 'sampled': new_samples} | |
save_data(args.output_file, args, new_data) | |
if __name__ == '__main__': | |
import argparse | |
from tqdm import tqdm | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--output_file', type=str, default="./exp_test/results/xsum_gpt2") | |
parser.add_argument('--dataset', type=str, default="xsum") | |
parser.add_argument('--dataset_file', type=str, default="./exp_test/data/xsum_gpt2") | |
parser.add_argument('--t5_model_name', type=str, default="Vamsi/T5_Paraphrase_Paws") | |
parser.add_argument('--paraphraser', type=str, default="t5", choices=["t5", "random"]) | |
parser.add_argument('--seed', type=int, default=0) | |
parser.add_argument('--device', type=str, default="cuda") | |
parser.add_argument('--cache_dir', type=str, default="../cache") | |
args = parser.parse_args() | |
torch.manual_seed(args.seed) | |
np.random.seed(args.seed) | |
import nltk | |
nltk.download('punkt') | |
generate_data(args) | |