attaelahi committed on
Commit
5f27c76
·
1 Parent(s): 0f26191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -22
app.py CHANGED
@@ -1,31 +1,106 @@
1
- # Filename: app.py
2
  import streamlit as st
3
- from transformers import pipeline
 
 
 
4
 
5
- # Load a different text classification model for spam detection
6
- classifier = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
7
 
8
- def main():
9
- st.title("Spam Detection App")
 
10
 
11
- # Text input for the user to enter a message
12
- user_input = st.text_input("Enter a message:")
13
 
14
- if st.button("Check for Spam"):
15
- if user_input:
16
- # Use the loaded model to classify the user's input
17
- result = classifier(user_input)[0]
18
 
19
- # Display the result
20
- st.write(f"**Result:** {result['label']} (Confidence: {result['score']:.2%})")
 
21
 
22
- # Show a message based on the classification
23
- if result['label'] == 'spam':
24
- st.error("This message is classified as spam.")
25
- else:
26
- st.success("This message is not spam.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  else:
28
- st.warning("Please enter a message before checking for spam.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- if __name__ == "__main__":
31
- main()
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
4
+ import torch
5
+ import os
6
 
7
# Directory that receives the fine-tuned model and tokenizer files.
FINE_TUNED_MODEL_DIR = "./fine_tuned_sms_spam_model"

# The tokenizer and the classifier are loaded from the same pretrained
# checkpoint, so the identifier is named once.
_PRETRAINED_CHECKPOINT = "wesleyacheng/sms-spam-classification-with-bert"
tokenizer = AutoTokenizer.from_pretrained(_PRETRAINED_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_PRETRAINED_CHECKPOINT)

# Page title of the Streamlit app.
st.title("SMS Spam Classification")
16
 
17
def classify_spam_or_ham(text):
    """Classify *text* with the module-level model and return its label.

    The text is tokenized into PyTorch tensors (padding and truncation
    enabled) and passed through the model with gradient tracking disabled.

    Returns:
        "Spam" when the logit at index 1 is strictly greater than the logit
        at index 0 for the first sequence in the batch, otherwise "Not-Spam".
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only, so no gradients are required.
    with torch.no_grad():
        result = model(**encoded)

    scores = result.logits[0]
    if scores[1] > scores[0]:
        return "Spam"
    return "Not-Spam"
29
+
30
st.write("Single SMS Example:")


def classify_single_sms(text):
    """Classify one SMS and write the message and its prediction to the page.

    Values that are not strings are not classified; a warning is shown
    instead.
    """
    if not isinstance(text, str):
        st.warning("Skipping non-text data.")
        return

    prediction = classify_spam_or_ham(text)
    # Message, prediction, then a separator line, in that order.
    for line in (f"SMS: {text}", f"Prediction: {prediction}", "--------"):
        st.write(line)
42
+
43
+ # Main Streamlit code for CSV file upload
44
# Sidebar: bulk classification of an uploaded CSV file.
# NOTE(review): block nesting was reconstructed from a flattened view of the
# file; confirm the indentation against the real app.py.
st.sidebar.header("Upload CSV File")
uploaded_file = st.sidebar.file_uploader(
    "Upload a CSV file with SMS messages:",
    type=["csv"],
)

if uploaded_file is not None:
    st.sidebar.write("Classifying SMS messages in the uploaded file...")
    try:
        df = pd.read_csv(uploaded_file, encoding='latin1')
    except UnicodeDecodeError:
        st.sidebar.error("Error: Unable to decode the CSV file. Please make sure it is in the correct encoding.")
    else:
        # The user picks which column holds the SMS text.
        selected_column = st.sidebar.selectbox("Select the SMS column:", df.columns)
        sms_column = df[selected_column]

        if sms_column.dtype != "object":
            st.sidebar.error("Selected column does not contain text data and cannot be tokenized.")
        else:
            st.write("Classifications:")
            for message in sms_column:
                classify_single_sms(message)

        st.sidebar.write("Classification completed!")
65
+
66
# Sidebar: classify a single message typed by the user.
st.sidebar.write("Or classify a single SMS:")
user_input = st.sidebar.text_area("Enter an SMS message:")
classify_clicked = st.sidebar.button("Classify")
if classify_clicked:
    if not user_input:
        st.sidebar.warning("Please enter an SMS message.")
    else:
        classify_single_sms(user_input)
73
+
74
# Main page: optional fine-tuning on the uploaded CSV column.
st.write("Or fine-tune the model:")
if st.button("Fine-Tune Model"):
    if uploaded_file is not None and selected_column and df[selected_column].dtype == "object":
        # Training configuration; checkpoints go to the fine-tuned model directory.
        training_args = TrainingArguments(
            output_dir=FINE_TUNED_MODEL_DIR,
            overwrite_output_dir=True,
            per_device_train_batch_size=8,
            num_train_epochs=3,
        )

        # The selected CSV column is handed to the trainer as-is.
        # NOTE(review): no tokenization or label column is applied here;
        # confirm that Trainer can actually consume this dataset.
        custom_dataset = df[selected_column]
        trainer = Trainer(model=model, args=training_args, train_dataset=custom_dataset)
        trainer.train()

        # Persist the model first, then the tokenizer, to the same directory.
        for artifact in (model, tokenizer):
            artifact.save_pretrained(FINE_TUNED_MODEL_DIR)
        st.write("Model has been fine-tuned and saved.")
    elif not uploaded_file:
        st.warning("Please upload a CSV file before fine-tuning.")
    elif not selected_column:
        st.warning("Please select the SMS column before fine-tuning.")
    else:
        st.warning("The selected column does not contain text data and cannot be used for fine-tuning.")