Credbox commited on
Commit
15b6598
·
verified ·
1 Parent(s): d72def5

Create Paraphrase.py

Browse files
Files changed (1) hide show
  1. Paraphrase.py +63 -0
Paraphrase.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
2
+ import torch
3
+
4
+ def chunk_text(text, max_length, tokenizer):
5
+ """Split text into chunks of a specified maximum token length."""
6
+ tokens = tokenizer.encode(text, truncation=False)
7
+ chunks = []
8
+ while len(tokens) > max_length:
9
+ chunk = tokens[:max_length]
10
+ tokens = tokens[max_length:]
11
+ chunks.append(chunk)
12
+ if tokens:
13
+ chunks.append(tokens)
14
+ return chunks
15
+
16
+ def adjust_lengths(paragraph_length):
17
+ """Adjust max_length and min_length based on the input length."""
18
+ if paragraph_length < 100:
19
+ return 100, 50 # Shorter paragraphs
20
+ elif paragraph_length < 500:
21
+ return 300, 150 # Medium-length paragraphs
22
+ else:
23
+ return 600, 300 # Longer paragraphs
24
+
25
+ def paraphrase_paragraph(paragraph, model_name='google/pegasus-multi_news'):
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
28
+ tokenizer = PegasusTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
29
+
30
+ # Tokenize the entire paragraph to calculate length
31
+ tokens = tokenizer.encode(paragraph, truncation=False)
32
+ paragraph_length = len(tokens)
33
+
34
+ # Adjust max_length and min_length dynamically
35
+ max_length, min_length = adjust_lengths(paragraph_length)
36
+
37
+ # Chunk the paragraph based on the model's token limit
38
+ chunks = chunk_text(paragraph, tokenizer.model_max_length, tokenizer)
39
+
40
+ paraphrased_chunks = []
41
+ for chunk in chunks:
42
+ # Decode chunk tokens back to text
43
+ chunk_t = tokenizer.decode(chunk, skip_special_tokens=True)
44
+ # Tokenize the text chunk
45
+ inputs = tokenizer(chunk_t, return_tensors='pt', padding=True, truncation=True).to(device)
46
+
47
+ # Generate paraphrased text
48
+ with torch.no_grad(): # Avoid gradient calculations for inference
49
+ generated_ids = model.generate(
50
+ inputs['input_ids'],
51
+ max_length=max_length, # Dynamically adjusted
52
+ min_length=min_length, # Dynamically adjusted
53
+ num_beams=3,
54
+ early_stopping=True
55
+ )
56
+
57
+ paraphrased_chunk = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
58
+ paraphrased_chunks.append(paraphrased_chunk)
59
+
60
+ # Combine all paraphrased chunks
61
+ paraphrased_paragraph = ' '.join(paraphrased_chunks)
62
+
63
+ return paraphrased_paragraph