from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
def chunk_text(text, max_length, tokenizer):
    """Split text into chunks of at most max_length tokens."""
    tokens = tokenizer.encode(text, truncation=False)
    chunks = []
    # Peel off full-size slices while more than max_length tokens remain.
    while len(tokens) > max_length:
        chunks.append(tokens[:max_length])
        tokens = tokens[max_length:]
    # Keep the shorter final remainder, if any.
    if tokens:
        chunks.append(tokens)
    return chunks
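# Illustrative example (hypothetical numbers, not from the original): a text
# that encodes to 12 tokens, chunked with max_length=5, yields chunks of
# 5, 5, and 2 tokens.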
def adjust_lengths(paragraph_length):
    """Adjust generation max_length and min_length to the input token count."""
    if paragraph_length < 100:
        return 100, 50
    elif paragraph_length < 500:
        return 300, 150
    else:
        return 600, 300
def paraphrase_paragraph(paragraph, model_name='google/pegasus-multi_news'):
    """Paraphrase a paragraph, chunking it to fit the model's context window."""
    # Note: loading the model and tokenizer on every call is slow; hoist them
    # out of the function if you process many paragraphs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = PegasusTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

    # Size the generation bounds from the full paragraph's token count.
    tokens = tokenizer.encode(paragraph, truncation=False)
    max_length, min_length = adjust_lengths(len(tokens))

    # Split the paragraph into chunks no longer than the model's context window.
    chunks = chunk_text(paragraph, tokenizer.model_max_length, tokenizer)

    paraphrased_chunks = []
    for chunk in chunks:
        # Round-trip through text so the chunk is re-encoded with the special
        # tokens the model expects; truncation=True guards against the
        # re-encoded chunk drifting past the context window.
        chunk_str = tokenizer.decode(chunk, skip_special_tokens=True)
        inputs = tokenizer(chunk_str, return_tensors='pt', padding=True, truncation=True).to(device)

        with torch.no_grad():
            # Pass the attention mask along with the input ids so padded
            # positions are ignored during generation.
            generated_ids = model.generate(
                **inputs,
                max_length=max_length,
                min_length=min_length,
                num_beams=3,
                early_stopping=True
            )

        paraphrased_chunks.append(
            tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        )

    # Stitch the per-chunk outputs back into a single paragraph.
    return ' '.join(paraphrased_chunks)
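# Minimal usage sketch (illustrative; the sample text below is an assumption,
# not part of the original listing). The first call downloads the model
# weights from the Hugging Face Hub, which can take a while.
if __name__ == "__main__":
    sample = (
        "Large language models can rewrite a passage while preserving its "
        "meaning. This script splits long inputs into model-sized chunks, "
        "paraphrases each chunk, and stitches the results back together."
    )
    print(paraphrase_paragraph(sample))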