Unityraptor committed on
Commit
889dd23
1 Parent(s): 1186b46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -30
app.py CHANGED
@@ -1,33 +1,39 @@
1
- pip install git+https://github.com/PrithivirajDamodaran/Parrot_Paraphraser.git
2
 
3
  import gradio as gr
4
- from parrot import Parrot
5
- import torch
6
- import warnings
7
- warnings.filterwarnings("ignore")
8
-
9
- '''
10
- uncomment to get reproducable paraphrase generations
11
- def random_state(seed):
12
- torch.manual_seed(seed)
13
- if torch.cuda.is_available():
14
- torch.cuda.manual_seed_all(seed)
15
-
16
- random_state(1234)
17
- '''
18
-
19
- #Init models (make sure you init ONLY once if you integrate this to your code)
20
- parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=False)
21
-
22
- phrases = ["Can you recommed some upscale restaurants in Newyork?",
23
- "What are the famous places we should not miss in Russia?"
24
- ]
25
-
26
- for phrase in phrases:
27
- print("-"*100)
28
- print("Input_phrase: ", phrase)
29
- print("-"*100)
30
- para_phrases = parrot.augment(input_phrase=phrase)
31
- for para_phrase in para_phrases:
32
- print(para_phrase)
 
 
 
 
 
 
 
33
 
 
 
1
 
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ device = "cuda"
6
+
7
+ tokenizer = AutoTokenizer.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
8
+
9
+ model = AutoModelForSeq2SeqLM.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base").to(device)
10
+
11
+ def paraphrase(
12
+ question,
13
+ num_beams=5,
14
+ num_beam_groups=5,
15
+ num_return_sequences=5,
16
+ repetition_penalty=10.0,
17
+ diversity_penalty=3.0,
18
+ no_repeat_ngram_size=2,
19
+ temperature=0.7,
20
+ max_length=128
21
+ ):
22
+ input_ids = tokenizer(
23
+ f'paraphrase: {question}',
24
+ return_tensors="pt", padding="longest",
25
+ max_length=max_length,
26
+ truncation=True,
27
+ ).input_ids.to(device)
28
+
29
+ outputs = model.generate(
30
+ input_ids, temperature=temperature, repetition_penalty=repetition_penalty,
31
+ num_return_sequences=num_return_sequences, no_repeat_ngram_size=no_repeat_ngram_size,
32
+ num_beams=num_beams, num_beam_groups=num_beam_groups,
33
+ max_length=max_length, diversity_penalty=diversity_penalty
34
+ )
35
+
36
+ res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
37
+
38
+ return res
39