michalisG
commited on
Commit
•
acea606
1
Parent(s):
74f37f1
Add application file-1
Browse files- app.py +65 -0
- config.json +35 -0
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import shap
|
3 |
+
import streamlit as st
|
4 |
+
import pandas as pd
|
5 |
+
import joblib
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
#
|
8 |
+
|
9 |
+
from utils.data_processor import DataProcessor
|
10 |
+
from utils.model_predictor import ModelPredictor
|
11 |
+
from utils.user_input_features_collector import UserInputDataCollector
|
12 |
+
|
13 |
+
model = joblib.load('resources/model.joblib')
|
14 |
+
categorical_names = joblib.load('resources/categorical_names.pkl')
|
15 |
+
target_labels = joblib.load('resources/target_labels.pkl')
|
16 |
+
selected_features = []
|
17 |
+
|
18 |
+
shap_explainer = shap.TreeExplainer(model.named_steps['RandomForestClassifier'])
|
19 |
+
data_processor = DataProcessor(model, categorical_names, selected_features)
|
20 |
+
predictor = ModelPredictor(model)
|
21 |
+
|
22 |
+
st.write("### Enter Patient Information for Diagnosis Prediction")
|
23 |
+
data = UserInputDataCollector.user_input_features()
|
24 |
+
user_input = pd.DataFrame(data, index=[0])
|
25 |
+
|
26 |
+
st.write("#### Patient Data")
|
27 |
+
st.write(user_input)
|
28 |
+
|
29 |
+
# In your Streamlit app, where you handle the "Predict" button:
|
30 |
+
if st.button("Predict"):
|
31 |
+
prediction, probabilities = predictor.predict(user_input)
|
32 |
+
col1, col2 = st.columns(2)
|
33 |
+
labels_map = {0: "Transplant/Death", 1: "Survive"}
|
34 |
+
label = labels_map.get(int(np.argmax(probabilities)))
|
35 |
+
|
36 |
+
# with col1:
|
37 |
+
# st.subheader("Prediction")
|
38 |
+
# st.write(label)
|
39 |
+
#
|
40 |
+
# with col2:
|
41 |
+
st.subheader("Prediction Probabilities")
|
42 |
+
# Create a DataFrame for the probabilities to display them in a more readable format
|
43 |
+
proba_df = pd.DataFrame(probabilities, columns=labels_map.values())
|
44 |
+
st.dataframe(proba_df) # Using st.dataframe to make it more interact
|
45 |
+
|
46 |
+
i = 0
|
47 |
+
preprocessed_input = data_processor.shap_and_eli5_custom_format(user_input)
|
48 |
+
shap_values = shap_explainer.shap_values(preprocessed_input)
|
49 |
+
# np.argmax(probabilities)
|
50 |
+
shap_explanation = shap.Explanation(values=shap_values[np.argmax(probabilities)][0, :],
|
51 |
+
base_values=shap_explainer.expected_value[np.argmax(probabilities)],
|
52 |
+
data=user_input.iloc[0, :],
|
53 |
+
feature_names=user_input.columns.tolist())
|
54 |
+
|
55 |
+
# Generate the SHAP waterfall plot
|
56 |
+
shap.plots.waterfall(shap_explanation, max_display=len(user_input.columns.tolist()), show=False)
|
57 |
+
# After generating the SHAP plot, grab the current figure
|
58 |
+
fig = plt.gcf()
|
59 |
+
fig.set_size_inches(10, 7, forward=True)
|
60 |
+
# Optionally, adjust the plot title or other properties here
|
61 |
+
fig.suptitle(f'Prediction: {label}', fontsize=20, y=1.05)
|
62 |
+
# Display the figure in Streamlit, passing it explicitly to ensure thread safety
|
63 |
+
st.pyplot(fig)
|
64 |
+
# Reset the default plot size if necessary
|
65 |
+
plt.rcParams['figure.figsize'] = plt.rcParamsDefault['figure.figsize']
|
config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "RandomForestClassifier",
|
3 |
+
"expected_features": [
|
4 |
+
"age",
|
5 |
+
"sex",
|
6 |
+
"serum_bilirubin",
|
7 |
+
"serum_cholesterol",
|
8 |
+
"albumin",
|
9 |
+
"alkaline_phosphatase",
|
10 |
+
"SGOT",
|
11 |
+
"platelets",
|
12 |
+
"prothrombin_time"
|
13 |
+
],
|
14 |
+
"categorical_features": [
|
15 |
+
"drug",
|
16 |
+
"sex",
|
17 |
+
"presence_of_ascites",
|
18 |
+
"presence_of_hepatomegaly",
|
19 |
+
"presence_of_spiders",
|
20 |
+
"presence_of_edema"
|
21 |
+
],
|
22 |
+
"model_parameters": {
|
23 |
+
"criterion": "entropy",
|
24 |
+
"max_features": 0.1,
|
25 |
+
"min_samples_split": 8,
|
26 |
+
"min_samples_leaf": 6,
|
27 |
+
"bootstrap": true
|
28 |
+
},
|
29 |
+
"version": "1.0",
|
30 |
+
"preprocessing": {
|
31 |
+
"numerical": "median imputation and scaling",
|
32 |
+
"categorical": "one-hot encoding",
|
33 |
+
"ordinal": "label encoding"
|
34 |
+
}
|
35 |
+
}
|