"""Paraphrase.py — paragraph paraphrasing with a Pegasus seq2seq model.

Provenance: originally published as "Rephrase / Paraphrase.py" by Credbox
(Hugging Face, commit 15b6598, 2.52 kB).
"""
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
def chunk_text(text, max_length, tokenizer):
    """Tokenize *text* and split the token ids into runs of at most *max_length*.

    Returns a list of token-id lists; an empty encoding yields an empty list.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    # Stride over the id sequence in max_length-sized steps; the final
    # slice naturally carries any shorter remainder.
    return [token_ids[start:start + max_length]
            for start in range(0, len(token_ids), max_length)]
def adjust_lengths(paragraph_length):
    """Pick (max_length, min_length) generation bounds from a token count.

    Longer inputs get proportionally larger generation windows.
    """
    if paragraph_length >= 500:
        return 600, 300  # long paragraphs
    if paragraph_length >= 100:
        return 300, 150  # medium-length paragraphs
    return 100, 50  # short paragraphs
def paraphrase_paragraph(paragraph, model_name='google/pegasus-multi_news'):
    """Paraphrase *paragraph* with a Pegasus sequence-to-sequence model.

    Long inputs are split into encoder-sized token chunks, each chunk is
    paraphrased independently, and the results are joined with spaces.

    Args:
        paragraph: Input text to paraphrase.
        model_name: Hugging Face model identifier.

    Returns:
        The paraphrased text as a single string (empty input yields '').
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()  # make inference mode explicit
    tokenizer = PegasusTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

    # Split the paragraph into chunks that fit the model's token limit.
    chunks = chunk_text(paragraph, tokenizer.model_max_length, tokenizer)

    paraphrased_chunks = []
    for chunk in chunks:
        # Fix: size the generation bounds per chunk rather than from the
        # whole paragraph — otherwise every chunk of a long input is forced
        # to emit at least `min_length` tokens, inflating the output.
        max_length, min_length = adjust_lengths(len(chunk))
        # Decode chunk token ids back to text, then re-tokenize for the model.
        chunk_str = tokenizer.decode(chunk, skip_special_tokens=True)
        inputs = tokenizer(chunk_str, return_tensors='pt', padding=True,
                           truncation=True).to(device)
        with torch.no_grad():  # no gradients needed for inference
            generated_ids = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],  # fix: was dropped before
                max_length=max_length,
                min_length=min_length,
                num_beams=3,
                early_stopping=True,
            )
        paraphrased_chunks.append(
            tokenizer.decode(generated_ids[0], skip_special_tokens=True))

    # Combine all paraphrased chunks into one paragraph.
    return ' '.join(paraphrased_chunks)