File size: 1,380 Bytes
28eaf77 f0df2f1 28eaf77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import streamlit as st
import torch
from transformers import(
T5TokenizerFast as T5Tokenizer)
import warnings
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")

# Name of the pretrained checkpoint whose tokenizer is used.
MODEL_NAME = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Inference runs on CPU; map_location remaps the checkpoint's tensors to it.
device = torch.device('cpu')

# NOTE(review): torch.load unpickles the file, so 'models.pth' must come from
# a trusted source. Newer PyTorch releases default to weights_only=True, which
# rejects fully pickled model objects -- confirm against the installed version.
model = torch.load('models.pth', map_location=device)

# A pickled module keeps whatever train/eval flag it was saved with; force
# eval mode so dropout is disabled during generation.
model.eval()
def summarize(text: str) -> str:
    """Summarize ``text`` with the module-level ``tokenizer`` and ``model``.

    The input is tokenized into a fixed 512-token window (longer input is
    truncated, shorter input is padded), a summary is generated with
    2-beam search capped at ``max_length=150``, and the decoded sequences
    are concatenated into one string.

    Args:
        text: The passage to summarize.

    Returns:
        The decoded summary with special tokens removed, or an empty
        string when ``text`` is empty or contains only whitespace.
    """
    # Nothing to summarize: skip tokenization and beam search, which would
    # otherwise run on blank input and produce a summary of nothing.
    if isinstance(text, str) and not text.strip():
        return ""

    encoding = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    # One decoded string per generated sequence.
    summaries = [
        tokenizer.decode(
            ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for ids in generated_ids
    ]
    return "".join(summaries)
def main():
    """Render the Streamlit page: a text box and a button that shows the summary."""
    st.title("T5 text summarizer with streamlit")
    # NOTE(review): the label says "512 words" while summarize() truncates
    # its input at 512 tokens -- confirm the intended wording.
    st.subheader("Summarize your 512 words here!")
    user_text = st.text_area("Enter your text", "Type Here")

    # Nothing more to draw until the button is pressed.
    if not st.button("Summarize text"):
        return

    st.write(summarize(user_text))
# Script entry point: build the page when this file is run directly.
if __name__ == "__main__":
    main()
|