File size: 1,380 Bytes
28eaf77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0df2f1
28eaf77
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import streamlit as st
import torch
from transformers import(
    T5TokenizerFast as T5Tokenizer)
import warnings
# Install a blanket "ignore" filter: every warning issued through the
# `warnings` module after this point is suppressed, regardless of category.
warnings.filterwarnings("ignore")

# Name of the pretrained checkpoint whose tokenizer is loaded below.
MODEL_NAME = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Everything runs on the CPU; `map_location` remaps tensors that were saved
# from another device onto it while loading.
device = torch.device('cpu')

# NOTE(security): `torch.load` unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules like this one appears to be -- confirm the
# installed version before upgrading.
model = torch.load('models.pth', map_location=device)

# A pickled module keeps whatever train/eval flag it had when it was saved.
# Force eval mode so dropout is disabled while generating summaries.
model.eval()


def summarize(text):
    """Generate a summary of ``text`` with the module-level model.

    The text is tokenized (truncated and padded to 512 tokens), passed to
    ``model.generate`` with 2-beam search, and the generated token ids are
    decoded back into text.

    Args:
        text: The text to summarize. A single string is the expected input;
            any other value is handed to the tokenizer unchanged.

    Returns:
        The decoded summary as one string. Empty or whitespace-only strings
        yield ``""`` without invoking the tokenizer or the model.
    """
    # Nothing to summarize: skip the (expensive) model call, which would
    # otherwise generate text from an input that carries no content.
    if isinstance(text, str) and not text.strip():
        return ""

    text_encoding = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,  # inputs longer than max_length are cut off
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"],
        # The mask lets the model ignore the padding added above.
        attention_mask=text_encoding["attention_mask"],
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )

    # One decoded string per generated sequence.
    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generated_ids
    ]

    return "".join(preds)


def main():
    """Text Summarizer app with streamlit.

    Renders a title, a subheader, a text area and a button. When the button
    is pressed, the text area's content is passed to ``summarize`` and the
    result is written to the page. Blank input produces a warning instead of
    a model call.
    """
    st.title("T5 text summarizer with streamlit")
    # NOTE(review): `summarize` truncates at 512 *tokens*, not words.
    st.subheader("Summarize your 512 words here!")
    message = st.text_area("Enter your text", "Type Here")

    # Only act on the run in which the button reports a click.
    if not st.button("Summarize text"):
        return

    # Give the user feedback rather than summarizing an empty input.
    if not message.strip():
        st.warning("Please enter some text to summarize.")
        return

    st.write(summarize(message))


# Build the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()