from functools import lru_cache

from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch


def chunk_text(text, max_length, tokenizer, *, add_special_tokens=True):
    """Split text into chunks of at most ``max_length`` token ids.

    Args:
        text: The text to tokenize and split.
        max_length: Maximum number of token ids per chunk; must be positive.
        tokenizer: Tokenizer whose ``encode`` method produces the token ids.
        add_special_tokens: Forwarded to ``tokenizer.encode``. The default
            (True) keeps the original behaviour of including the tokenizer's
            special tokens in the token stream before splitting.

    Returns:
        A list of token-id lists; the last chunk may be shorter. Returns an
        empty list when the text encodes to no tokens.

    Raises:
        ValueError: If ``max_length`` is not positive (previously this looped
            forever because the remaining tokens never shrank).
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    tokens = tokenizer.encode(
        text, truncation=False, add_special_tokens=add_special_tokens
    )
    return [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]


def adjust_lengths(paragraph_length):
    """Pick generation ``(max_length, min_length)`` for an input token count."""
    if paragraph_length < 100:
        return 100, 50  # Shorter inputs
    elif paragraph_length < 500:
        return 300, 150  # Medium-length inputs
    else:
        return 600, 300  # Longer inputs


@lru_cache(maxsize=2)
def _load_model_and_tokenizer(model_name):
    """Load the Pegasus model/tokenizer once per model name and reuse them."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()
    tokenizer = PegasusTokenizer.from_pretrained(
        model_name, clean_up_tokenization_spaces=True
    )
    return model, tokenizer


def _input_token_limit(model, tokenizer):
    """Return the longest input (in tokens, including EOS) the model accepts."""
    limit = tokenizer.model_max_length
    # model_max_length may be a huge "unset" sentinel; the position-embedding
    # table is the real hard limit when the config exposes it.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if max_positions:
        limit = min(limit, max_positions)
    return limit


def paraphrase_paragraph(paragraph, model_name='google/pegasus-multi_news'):
    """Rewrite ``paragraph`` with a Pegasus seq2seq model, chunk by chunk.

    The paragraph is split into chunks that fit the model's input limit. Each
    chunk is generated independently, with output lengths chosen from that
    chunk's own token count, and the results are joined with single spaces.

    Args:
        paragraph: Input text. Empty or whitespace-only input returns ``""``.
        model_name: Hugging Face model id; the loaded model and tokenizer are
            cached per name across calls.

    Returns:
        The generated text for all chunks, space-joined.
    """
    if not paragraph or not paragraph.strip():
        return ""

    model, tokenizer = _load_model_and_tokenizer(model_name)
    device = model.device

    # Chunk without special tokens and reserve one slot per chunk for the EOS
    # appended below, so no chunk ever exceeds the model's input limit.
    chunk_size = max(_input_token_limit(model, tokenizer) - 1, 1)
    chunks = chunk_text(paragraph, chunk_size, tokenizer, add_special_tokens=False)

    paraphrased_chunks = []
    for chunk in chunks:
        # Build the model input directly from the ids (chunk + EOS), which is
        # how the Pegasus tokenizer frames a single sequence; this avoids the
        # lossy decode/re-encode round trip.
        input_ids = torch.tensor([chunk + [tokenizer.eos_token_id]], device=device)
        attention_mask = torch.ones_like(input_ids)

        # Size the output from this chunk, not the whole paragraph, so a short
        # trailing chunk is not forced to a long minimum length.
        max_length, min_length = adjust_lengths(input_ids.shape[-1])

        with torch.no_grad():  # Inference only; no gradients needed.
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                min_length=min_length,
                num_beams=3,
                early_stopping=True,
            )
        paraphrased_chunks.append(
            tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        )

    return ' '.join(paraphrased_chunks)