Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
-
import streamlit as st
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
-
import
|
5 |
from sklearn.model_selection import train_test_split
|
6 |
from sklearn.feature_extraction.text import CountVectorizer
|
7 |
from sklearn.neighbors import KNeighborsClassifier
|
@@ -10,16 +9,13 @@ from sklearn.tree import DecisionTreeClassifier
|
|
10 |
from sklearn.linear_model import LogisticRegression
|
11 |
from sklearn.svm import SVC
|
12 |
from sklearn.metrics import accuracy_score
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
# Load the dataset from spam.csv (the path is resolved against the current
# working directory).
df = pd.read_csv("spam.csv")

# Page heading shown at the top of the Streamlit UI.
st.title("Identifying Spam and Ham Emails")

# The "Message" column is the model input and "Category" is the target.
x, y = df["Message"], df["Category"]
|
@@ -40,54 +36,38 @@ models = {
|
|
40 |
"SVM": SVC()
|
41 |
}
|
42 |
|
43 |
-
#
|
44 |
-
model_choice =
|
45 |
-
|
46 |
-
# Look up the estimator chosen via `model_choice`, fit it on the training
# split, and score it on the held-out split.
obj = models[model_choice]
obj.fit(x_train, y_train)
y_pred = obj.predict(x_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)

# Show the held-out accuracy on the page, rounded to four decimals.
st.write(f"Accuracy of {model_choice}: {accuracy:.4f}")
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
def predict_email(email):
    """Classify a single email and write the label to the Streamlit page.

    The text is transformed with the module-level ``bow`` object (defined
    outside this excerpt; presumably the fitted CountVectorizer), densified,
    and passed to the trained model ``obj``. The first predicted label is
    displayed with ``st.write``; nothing is returned.
    """
    features = bow.transform([email]).toarray()
    label = obj.predict(features)[0]
    st.write(f"Prediction: {label}")
|
63 |
-
|
64 |
-
# When the button is pressed, classify the entered text; if nothing was
# entered, show a red prompt instead.
if st.button("Predict Email"):
    if not email_input:
        st.write(":red[Please enter an email to classify]")
    else:
        predict_email(email_input)
|
69 |
-
|
70 |
-
# FastAPI app to handle GET requests
|
71 |
-
app = FastAPI()
|
72 |
-
|
73 |
-
@app.get("/predict")
|
74 |
-
def predict_spam(email: str):
|
75 |
"""
|
76 |
This endpoint predicts whether the email is Spam or Ham.
|
77 |
Query parameter: email (str) - The email text to be classified.
|
78 |
"""
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
91 |
|
92 |
-
#
|
93 |
-
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
+
from flask import Flask, request, jsonify
|
4 |
from sklearn.model_selection import train_test_split
|
5 |
from sklearn.feature_extraction.text import CountVectorizer
|
6 |
from sklearn.neighbors import KNeighborsClassifier
|
|
|
9 |
from sklearn.linear_model import LogisticRegression
|
10 |
from sklearn.svm import SVC
|
11 |
from sklearn.metrics import accuracy_score
|
12 |
+
|
13 |
+
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)

# Load the dataset from spam.csv (the path is resolved against the current
# working directory).
df = pd.read_csv("spam.csv")

# The "Message" column is the model input and "Category" is the target.
x, y = df["Message"], df["Category"]
|
|
|
36 |
"SVM": SVC()
|
37 |
}
|
38 |
|
39 |
+
# Pick the estimator to serve. The key must exist in `models`; this is the
# default used by every endpoint.
model_choice = "Naive Bayes"

# Fit once at import time, then score on the held-out split so the result can
# be reported later without recomputing it.
obj = models[model_choice]
obj.fit(x_train, y_train)
y_pred = obj.predict(x_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)

# Log the held-out accuracy to stdout as a start-up sanity check.
print(f"Accuracy of {model_choice}: {accuracy:.4f}")
|
|
|
48 |
|
49 |
+
@app.route('/predict', methods=['GET'])
def predict_spam():
    """Classify the email text supplied in the ``email`` query parameter.

    Returns:
        A JSON response ``{"prediction": <label>}`` on success, or
        ``{"error": ...}`` with HTTP status 400 when ``email`` is missing
        or empty.
    """
    email = request.args.get('email')

    # Guard clause: a missing or empty parameter is a client error.
    if not email:
        return jsonify({"error": "Please provide an 'email' query parameter."}), 400

    # Vectorize the single message with the module-level `bow` object, then
    # take the first (only) label predicted by the trained model `obj`.
    features = bow.transform([email]).toarray()
    label = obj.predict(features)[0]
    return jsonify({"prediction": label})
|
63 |
|
64 |
+
@app.route('/accuracy', methods=['GET'])
def get_accuracy():
    """Return the accuracy of the selected model on the test data as JSON.

    The value is the module-level ``accuracy`` computed once when the model
    was trained at import time; it is not recomputed per request.
    """
    payload = {"accuracy": accuracy}
    return jsonify(payload)
|
70 |
|
71 |
+
# Start Flask's built-in server only when this file is executed directly.
# NOTE(review): debug=True turns on Flask's debug mode; confirm this is only
# ever used for local development.
if __name__ == "__main__":
    app.run(debug=True, port=5001, host="127.0.0.1")
|