|
import torch |
|
import streamlit as st |
|
from transformers import PegasusForConditionalGeneration, AutoTokenizer |
|
|
|
# NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases;
# ``st.cache_resource`` is the intended replacement for unhashable objects
# such as models. Migrate once the minimum supported Streamlit version allows.
@st.cache(allow_output_mutation=True)
def do_summary(model_name):
    """Load a Pegasus seq2seq model and place it on the best available device.

    Despite its name, this function only *loads* the model; it does not
    summarize anything. The Streamlit cache decorator keeps the returned
    object so the weights are not reloaded on every script rerun for the
    same ``model_name``.

    Args:
        model_name: Hugging Face Hub id or local path of a Pegasus checkpoint.

    Returns:
        The loaded ``PegasusForConditionalGeneration``, moved to CUDA when a
        GPU is available and otherwise left on the CPU.
    """
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    # Put the weights on the GPU when one exists. Callers that move their
    # input tensors to "cuda" under the same condition then see the model and
    # the inputs on one device instead of a CPU/CUDA mismatch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)
|
|
|
@st.cache(allow_output_mutation=True)
def do_tokenize(model_name):
    """Load the tokenizer for ``model_name``, cached across Streamlit reruns.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        Whatever tokenizer class ``AutoTokenizer`` resolves for that checkpoint.
    """
    return AutoTokenizer.from_pretrained(model_name)
|
|
|
# Single source of truth for the checkpoint id. The model and its tokenizer
# must come from the same checkpoint, so both loaders read this constant.
MODEL_NAME = "google/pegasus-cnn_dailymail"

model = do_summary(MODEL_NAME)
tokenizer = do_tokenize(MODEL_NAME)
|
|
|
|
|
def summarize(passage, **generate_kwargs):
    """Summarize text with the module-level Pegasus ``model`` and ``tokenizer``.

    Args:
        passage: Either a single string, or an iterable of text fragments
            (e.g. sentences) that are joined with single spaces before
            tokenization. Input longer than the tokenizer's maximum length
            is truncated (``truncation=True``).
        **generate_kwargs: Optional extra keyword arguments forwarded to
            ``model.generate`` (for example ``max_length`` or ``num_beams``).

    Returns:
        A list of decoded summary strings, as produced by
        ``tokenizer.batch_decode``. With default generation settings this
        holds one element, because one joined passage is encoded.
    """
    # A plain string is itself iterable, so joining it would put a space
    # between every character ("abc" -> "a b c"). Use strings as-is.
    txt = passage if isinstance(passage, str) else " ".join(passage)

    # Send the encoded inputs to whatever device the model's weights are on.
    # Choosing the device independently (e.g. from torch.cuda.is_available())
    # raises a device-mismatch error when the model was left on the CPU.
    batch = tokenizer(
        txt, truncation=True, padding="longest", return_tensors="pt"
    ).to(model.device)

    summary_ids = model.generate(**batch, **generate_kwargs)
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

    print("summ end")
    return summaries
|
|