Lauraayu committed on
Commit
5888fc0
·
verified ·
1 Parent(s): 203fe06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -1,22 +1,29 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
- import torch
4
 
5
- # Define the summarization pipeline
6
- summarizer_ntg = pipeline("text2text-generation", model="mrm8488/t5-base-finetuned-summarize-news")
 
7
 
 
 
 
 
 
 
8
 
9
- # Streamlit application title
10
- st.title("News Article Summarizer and Classifier")
11
- st.write("Enter a news article text to get its summary and category.")
12
 
13
- # Text input for user to enter the news article text
14
- text = st.text_area("Enter the news article text here:")
15
 
16
- # Perform summarization and classification when the user clicks the "Classify" button
17
- if st.button("Classify"):
18
- # Perform text summarization
19
- summary = summarizer_ntg(text)[0]['summary_text']
20
-
21
- # Display the summary and classification result
22
- st.write("Summary:", summary)
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelWithLMHead
 
3
 
4
+ # Load the model and tokenizer
5
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
6
+ model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
7
 
8
+ # Define the summarization function
9
+ def summarize(text, max_length=150):
10
+ input_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
11
+ generated_ids = model.generate(input_ids=input_ids, num_beams=2, max_length=max_length, repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)
12
+ preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
13
+ return preds[0]
14
 
15
+ # Streamlit application interface
16
+ st.title("News Summarization App")
17
+ st.write("Enter the news article text below to generate a summary.")
18
 
19
+ article = st.text_area("News Article", height=300)
20
+ max_len = st.slider("Max Length of Summary", min_value=50, max_value=300, value=150)
21
 
22
+ if st.button("Summarize"):
23
+ if article:
24
+ with st.spinner("Generating summary..."):
25
+ summary = summarize(article, max_length=max_len)
26
+ st.write("**Summary:**")
27
+ st.write(summary)
28
+ else:
29
+ st.error("Please enter some text to summarize.")