import streamlit as st
import torch
from transformers import T5TokenizerFast as T5Tokenizer
import warnings

warnings.filterwarnings("ignore")

MODEL_NAME = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

device = torch.device('cpu')
model = torch.load('models.pth', map_location=device)
model.eval()  # put the model in inference mode (disables dropout)
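# NOTE (assumption): 'models.pth' is expected to hold a full fine-tuned T5 model object
# saved with torch.save(model, 'models.pth'). If the checkpoint only contains a
# state_dict, it would instead be loaded roughly like this:
#   from transformers import T5ForConditionalGeneration
#   model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
#   model.load_state_dict(torch.load('models.pth', map_location=device))
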
def summarize(text):
    """Generate an abstractive summary of `text` with the loaded T5 model."""
    # Tokenize the input; anything beyond 512 tokens is truncated.
    text_encoding = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    # Beam-search decoding with a repetition penalty to discourage repeated phrases.
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for gen_id in generated_ids
    ]

    return "".join(preds)

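# Streamlit re-runs this script from top to bottom on every user interaction, so the
# checkpoint above is re-loaded each time. If that becomes a bottleneck, the load could
# be wrapped in a cached helper, e.g. (sketch, assuming a Streamlit version that
# provides st.cache_resource):
#   @st.cache_resource
#   def load_model():
#       return torch.load('models.pth', map_location=torch.device('cpu'))
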
def main():
    """Text summarizer app built with Streamlit."""
    st.title("T5 text summarizer with Streamlit")
    st.subheader("Summarize your text here (up to 512 tokens)!")
    message = st.text_area("Enter your text", "Type Here")
    if st.button("Summarize text"):
        summary_results = summarize(message)
        st.write(summary_results)


if __name__ == '__main__':
    main()
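
# To launch the app locally (assuming this file is saved as, e.g., app.py):
#   streamlit run app.py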