Update app.py
app.py
CHANGED
@@ -1,10 +1,10 @@
 import torch
 from transformers import PegasusForConditionalGeneration, PegasusTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
 
-model_name = '
+model_name = 'prithivida/parrot_paraphraser_on_T5'
 torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-tokenizer =
-model =
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(torch_device)
 
 def get_response(input_text,num_return_sequences):
     batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
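
For reference, below is a minimal standalone sketch of how the new prithivida/parrot_paraphraser_on_T5 checkpoint can be loaded and queried with the Transformers API. The "paraphrase: " input prefix, the generation settings, and the example sentence are illustrative assumptions, not part of the committed app.py; since prepare_seq2seq_batch is deprecated in newer Transformers releases, the sketch calls the tokenizer directly.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = 'prithivida/parrot_paraphraser_on_T5'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the tokenizer with AutoTokenizer and the seq2seq model with AutoModelForSeq2SeqLM.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(torch_device)

def get_response(input_text, num_return_sequences):
    # The Parrot T5 checkpoint expects a "paraphrase: " task prefix (assumption for illustration).
    batch = tokenizer(
        ["paraphrase: " + input_text],
        truncation=True,
        padding='longest',
        max_length=60,
        return_tensors="pt",
    ).to(torch_device)
    # Beam search over several candidates; these settings are illustrative, not taken from the commit.
    outputs = model.generate(
        **batch,
        max_length=60,
        num_beams=max(10, num_return_sequences),
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

print(get_response("Can you recommend some upscale restaurants in New York?", 3))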