Lauraayu commited on
Commit
e1fb94e
1 Parent(s): 35358bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -21
app.py CHANGED
@@ -1,26 +1,43 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
3
 
4
- def main():
5
- # Define pipelines
6
- summarizer_ntg = pipeline(model="mrm8488/t5-base-finetuned-summarize-news")
7
- model = pipeline(model="Lauraayu/News_Classi_Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Streamlit application title
10
- st.title("News Article Classifier")
11
- st.write("Enter a news article text to get its category:")
12
 
13
- # Text input for user to enter the news article text
14
- text = st.text_area("Enter the news article text here:")
 
 
15
 
16
- # Perform summarization and classification when the user clicks the "Classify" button
17
- if st.button("Classify"):
18
- # Perform text summarization
19
- summary = summarizer_ntg(text)[0]
20
-
21
- # Perform classification
22
- output = model(summary)
23
- category = output[0]["label"]
24
-
25
- # Display the summary and classification result
26
- st.write("Category:", category)
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
 
5
+ # Define the summarization pipeline
6
+ summarizer_ntg = pipeline("summarization", model="mrm8488/t5-base-finetuned-summarize-news")
7
+
8
+ # Load the tokenizer and model for classification
9
+ tokenizer_bb = AutoTokenizer.from_pretrained("your-username/your-model-name")
10
+ model_bb = AutoModelForSequenceClassification.from_pretrained("your-username/your-model-name")
11
+
12
+ # Streamlit application title
13
+ st.title("News Article Summarizer and Classifier")
14
+ st.write("Enter a news article text to get its summary and category.")
15
+
16
+ # Text input for user to enter the news article text
17
+ text = st.text_area("Enter the news article text here:")
18
+
19
+ # Perform summarization and classification when the user clicks the "Classify" button
20
+ if st.button("Classify"):
21
+ # Perform text summarization
22
+ summary = summarizer_ntg(text)[0]['summary_text']
23
 
24
+ # Tokenize the summarized text
25
+ inputs = tokenizer_bb(summary, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
26
 
27
+ # Move inputs and model to the same device (GPU or CPU)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ inputs = {k: v.to(device) for k, v in inputs.items()}
30
+ model_bb.to(device)
31
 
32
+ # Perform text classification
33
+ with torch.no_grad():
34
+ outputs = model_bb(**inputs)
35
+
36
+ # Get the predicted label
37
+ predicted_label_id = torch.argmax(outputs.logits, dim=-1).item()
38
+ label_mapping = model_bb.config.id2label
39
+ predicted_label = label_mapping[predicted_label_id]
40
+
41
+ # Display the summary and classification result
42
+ st.write("Summary:", summary)
43
+ st.write("Category:", predicted_label)