Pamudu13 commited on
Commit
7e8cccb
·
verified ·
1 Parent(s): cd3da46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -47
app.py CHANGED
@@ -1,7 +1,6 @@
1
- import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- import sklearn
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.feature_extraction.text import CountVectorizer
7
  from sklearn.neighbors import KNeighborsClassifier
@@ -10,16 +9,13 @@ from sklearn.tree import DecisionTreeClassifier
10
  from sklearn.linear_model import LogisticRegression
11
  from sklearn.svm import SVC
12
  from sklearn.metrics import accuracy_score
13
- from fastapi import FastAPI
14
- from fastapi.responses import JSONResponse
15
- import threading
16
 
17
  # Read dataset
18
  df = pd.read_csv(r"spam.csv")
19
 
20
- # Initialize Streamlit app
21
- st.title("Identifying Spam and Ham Emails")
22
-
23
  # Define feature and target variables
24
  x = df["Message"]
25
  y = df["Category"]
@@ -40,54 +36,38 @@ models = {
40
  "SVM": SVC()
41
  }
42
 
43
- # Model selection
44
- model_choice = st.selectbox("Choose a Classification Algorithm", list(models.keys()))
45
-
46
- # Train the selected model
47
  obj = models[model_choice]
48
  obj.fit(x_train, y_train)
49
  y_pred = obj.predict(x_test)
50
  accuracy = accuracy_score(y_test, y_pred)
51
 
52
- # Display accuracy
53
- if st.button("Show Accuracy"):
54
- st.write(f"Accuracy of {model_choice}: {accuracy:.4f}")
55
 
56
- # Email input and prediction function
57
- email_input = st.text_input("Enter an Email for Prediction")
58
-
59
- def predict_email(email):
60
- data = bow.transform([email]).toarray()
61
- prediction = obj.predict(data)[0]
62
- st.write(f"Prediction: {prediction}")
63
-
64
- if st.button("Predict Email"):
65
- if email_input:
66
- predict_email(email_input)
67
- else:
68
- st.write(":red[Please enter an email to classify]")
69
-
70
- # FastAPI app to handle GET requests
71
- app = FastAPI()
72
-
73
- @app.get("/predict")
74
- def predict_spam(email: str):
75
  """
76
  This endpoint predicts whether the email is Spam or Ham.
77
  Query parameter: email (str) - The email text to be classified.
78
  """
79
- data = bow.transform([email]).toarray()
80
- prediction = obj.predict(data)[0]
81
- return JSONResponse(content={"prediction": prediction})
82
-
83
- # Running FastAPI in a separate thread to work alongside Streamlit
84
- def run_api():
85
- import uvicorn
86
- uvicorn.run(app, host="0.0.0.0", port=8000)
87
 
88
- # Start FastAPI in a separate thread
89
- api_thread = threading.Thread(target=run_api, daemon=True)
90
- api_thread.start()
 
 
 
91
 
92
- # You can also check API response using the link below:
93
- # http://localhost:8000/predict/?email=Your_email_text_here
 
 
 
1
  import pandas as pd
2
  import numpy as np
3
+ from flask import Flask, request, jsonify
4
  from sklearn.model_selection import train_test_split
5
  from sklearn.feature_extraction.text import CountVectorizer
6
  from sklearn.neighbors import KNeighborsClassifier
 
9
  from sklearn.linear_model import LogisticRegression
10
  from sklearn.svm import SVC
11
  from sklearn.metrics import accuracy_score
12
+
13
+ # Initialize Flask app
14
+ app = Flask(__name__)
15
 
16
  # Read dataset
17
  df = pd.read_csv(r"spam.csv")
18
 
 
 
 
19
  # Define feature and target variables
20
  x = df["Message"]
21
  y = df["Category"]
 
36
  "SVM": SVC()
37
  }
38
 
39
+ # Choose and train a model
40
+ model_choice = "Naive Bayes" # Default model
 
 
41
  obj = models[model_choice]
42
  obj.fit(x_train, y_train)
43
  y_pred = obj.predict(x_test)
44
  accuracy = accuracy_score(y_test, y_pred)
45
 
46
+ # Print accuracy for initial check
47
+ print(f"Accuracy of {model_choice}: {accuracy:.4f}")
 
48
 
49
+ @app.route('/predict', methods=['GET'])
50
+ def predict_spam():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  """
52
  This endpoint predicts whether the email is Spam or Ham.
53
  Query parameter: email (str) - The email text to be classified.
54
  """
55
+ email = request.args.get('email')
56
+
57
+ if email:
58
+ data = bow.transform([email]).toarray() # Transform email using the Bag of Words vectorizer
59
+ prediction = obj.predict(data)[0] # Get the prediction (Spam or Ham)
60
+ return jsonify({"prediction": prediction}) # Return prediction as JSON
61
+ else:
62
+ return jsonify({"error": "Please provide an 'email' query parameter."}), 400
63
 
64
+ @app.route('/accuracy', methods=['GET'])
65
+ def get_accuracy():
66
+ """
67
+ Endpoint to check the accuracy of the selected model on the test data.
68
+ """
69
+ return jsonify({"accuracy": accuracy})
70
 
71
+ # Run Flask app
72
+ if __name__ == '__main__':
73
+ app.run(host='127.0.0.1', port=5001, debug=True)