# Captured from a Hugging Face Space; the Space status read "Runtime error" at capture time.
from pathlib import Path

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.pipeline import make_pipeline
# Human-readable feature labels (not referenced by the code visible in this file).
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]

# Connect to the Hopsworks project and its feature store.
project = hopsworks.login(project="SonyaStern_Lab1")
fs = project.get_feature_store()

# Download the registered model artefact and load the fitted estimator.
print("trying to dl model")
mr = project.get_model_registry()
registered_model = mr.get_model("diabetes_model", version=1)
model_dir = registered_model.download()
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code; only load artefacts from a model registry you trust.
model = joblib.load(Path(model_dir) / "diabetes_model.pkl")
print("Model downloaded")

# Feature group backing both the feature view and the reference data frame.
diabetes_fg = fs.get_feature_group(name="diabetes_gan", version=1)
query = diabetes_fg.select_all()
feature_view = fs.get_or_create_feature_view(
    name="diabetes_gan",
    version=1,
    description="Read from Diabetes dataset",
    labels=["diabetes"],
    query=query,
)

# Full feature-group contents, used later to compute per-age reference means.
diabetes_df = pd.DataFrame(diabetes_fg.read())
# Page layout: a title row, an input/output row, and a collapsible
# explainability section. Components are bound to names so the click handler
# can be wired to them afterwards.
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML(value="<h1 style='text-align: center;'>Diabetes prediction</h1>")

    with gr.Row():
        # Left column: user inputs and the submit button.
        with gr.Column():
            age_input = gr.Number(label="age")
            bmi_input = gr.Slider(10, 100, label="bmi", info="Body Mass Index")
            hba1c_input = gr.Slider(
                3.5, 9, label="hba1c_level", info="Glycated Haemoglobin"
            )
            blood_glucose_input = gr.Slider(
                80, 300, label="blood_glucose_level", info="Blood Glucose Level"
            )
            existent_info_input = gr.Radio(
                ["yes", "no", "Don't know"],
                label="Do you already know if you have diabetes? (This will not be used for the prediction)",
            )
            consent_input = gr.Checkbox(
                info="I consent that my personal data will be saved and potentially be used for the model training",
                label="accept",
            )
            btn = gr.Button("Submit")

        # Right column: prediction text and the reference-comparison chart.
        with gr.Column():
            with gr.Row():
                output = gr.Text(label="Model prediction")
            with gr.Row():
                mean_plot = gr.Plot()

    # NOTE(review): the original indentation was lost; this row is assumed to
    # sit at the top level of the page rather than inside the right column.
    with gr.Row():
        with gr.Accordion("See model explainability", open=False):
            with gr.Row():
                with gr.Column():
                    waterfall_plot = gr.Plot()
                with gr.Column():
                    summary_plot = gr.Plot()
            with gr.Row():
                with gr.Column():
                    importance_plot = gr.Plot()
                with gr.Column():
                    decision_plot = gr.Plot()
def _reference_comparison_figure(reference, user_values):
    """Return a grouped bar chart of reference means next to the user's values."""
    categories = ["BMI", "HbA1c", "Blood Level"]
    bar_width = 0.35
    indices = np.arange(len(categories))

    fig, ax = plt.subplots()
    ax.bar(
        indices,
        [
            reference["bmi"],
            reference["hba1c_level"],
            reference["blood_glucose_level"],
        ],
        bar_width,
        label="Reference",
        color="b",
        alpha=0.7,
    )
    ax.bar(
        indices + bar_width,
        user_values,
        bar_width,
        label="User",
        color="r",
        alpha=0.7,
    )
    ax.legend()
    ax.set_xlabel("Variables")
    ax.set_ylabel("Values")
    ax.set_title("Comparison with average non-diabetic values for your age")
    ax.set_xticks(indices + bar_width / 2)
    ax.set_xticklabels(categories)
    return fig


def _per_class_shap_values(shap_values):
    """Normalise TreeExplainer output to a list with one (samples, features) array per class."""
    if isinstance(shap_values, list):
        # Legacy SHAP layout: already one array per class.
        return shap_values
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Newer SHAP layout: (samples, features, classes).
        return [values[:, :, k] for k in range(values.shape[2])]
    return [values]


def submit_inputs(
    age_input,
    bmi_input,
    hba1c_input,
    blood_glucose_input,
    existent_info_input,
    consent_input,
):
    """Predict diabetes for one user and build the accompanying plots.

    Args:
        age_input: Age entered by the user.
        bmi_input: Body mass index.
        hba1c_input: HbA1c level.
        blood_glucose_input: Blood glucose level.
        existent_info_input: The user's own answer about having diabetes;
            stored with the submission, never fed to the model.
        consent_input: When truthy, the submission is inserted into the
            ``user_diabetes_data`` feature group.

    Returns:
        A 6-tuple ``(prediction, comparison_fig, waterfall_fig, summary_fig,
        importance_fig, decision_fig)`` matching the outputs wired to the
        submit button.

    Raises:
        gr.Error: If any of the numeric inputs is missing.
    """
    # Fail with a readable message instead of an opaque error inside the model
    # (presumably an empty gr.Number arrives as None - confirm for the
    # installed Gradio version).
    if any(
        value is None
        for value in (age_input, bmi_input, hba1c_input, blood_glucose_input)
    ):
        raise gr.Error("Please fill in age, BMI, HbA1c and blood glucose.")

    df = pd.DataFrame(
        [[age_input, bmi_input, hba1c_input, blood_glucose_input]],
        columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
    )
    res = model.predict(df)

    # Reference values: non-diabetic rows with exactly the user's age. Only the
    # plotted columns are averaged so unrelated or non-numeric columns cannot
    # break the mean. If no row matches the age the means are NaN.
    cohort = diabetes_df[
        (diabetes_df["diabetes"] == 0) & (diabetes_df["age"] == age_input)
    ]
    mean_for_age = cohort[["bmi", "hba1c_level", "blood_glucose_level"]].mean()
    print(
        "your bmi is:", bmi_input, "the mean for ur age is :", mean_for_age["bmi"]
    )
    fig = _reference_comparison_figure(
        mean_for_age, [bmi_input, hba1c_input, blood_glucose_input]
    )

    ## explainability plots
    shap_feature_names = ["age", "bmi", "hba1c", "glucose"]
    rf_classifier = model.named_steps["randomforestclassifier"]
    # Re-assemble every step except the classifier so the input can be
    # transformed into the feature space the random forest was trained on.
    transformer_pipeline = make_pipeline(
        *[
            step
            for name, step in model.named_steps.items()
            if name != "randomforestclassifier"
        ]
    )
    transformed_df = transformer_pipeline.transform(df)

    explainer = shap.TreeExplainer(rf_classifier)
    shap_values = _per_class_shap_values(explainer.shap_values(transformed_df))

    # predict() returns a class label; translate it to its position in
    # classes_, which is the order of the per-class SHAP outputs.
    predicted_class = rf_classifier.predict(transformed_df)[0]
    class_index = list(rf_classifier.classes_).index(predicted_class)
    shap_values_for_predicted_class = shap_values[class_index]
    base_value = np.atleast_1d(explainer.expected_value)[class_index]

    # SHAP values of the single submitted row, for the predicted class.
    shap_explanation = shap.Explanation(
        values=shap_values_for_predicted_class[0],
        base_values=base_value,
        data=df.iloc[0],
        feature_names=shap_feature_names,
    )

    # show=False on every SHAP call: the figures are returned to Gradio, so
    # SHAP must not call plt.show() on them.
    fig2 = plt.figure(figsize=(3, 3))
    fig2.tight_layout()
    plt.gca().set_position((0, 0, 1, 1))
    plt.title("SHAP Waterfall Plot")
    plt.tight_layout()
    plt.tick_params(axis="y", labelsize=3)
    shap.waterfall_plot(shap_explanation, show=False)

    fig3 = plt.figure(figsize=(3, 3))
    plt.title("SHAP Summary Plot")
    shap.summary_plot(
        shap_values,
        features=transformed_df,
        feature_names=shap_feature_names,
        show=False,
    )

    fig4 = plt.figure(figsize=(4, 3))
    feature_importances = rf_classifier.feature_importances_
    plt.title("Feature Importances")
    sns.barplot(x=feature_importances, y=shap_feature_names)

    fig5 = plt.figure(figsize=(3, 3))
    plt.title("SHAP Decision Plot")
    shap.decision_plot(
        base_value,
        shap_values_for_predicted_class,
        df.iloc[0],
        show=False,
    )

    ## save user's data in hopsworks
    if consent_input:
        user_data_fg = fs.get_or_create_feature_group(
            name="user_diabetes_data",
            version=1,
            primary_key=["age", "bmi", "hba1c_level", "blood_glucose_level"],
            description="Submitted user data",
        )
        user_data_df = df.copy()
        user_data_df["diabetes"] = existent_info_input
        user_data_fg.insert(user_data_df)
        print("inserted new user data to hopsworks", user_data_df)

    # Unregister the figures from pyplot so they do not pile up across
    # requests; the Figure objects themselves stay renderable.
    for figure in (fig, fig2, fig3, fig4, fig5):
        plt.close(figure)

    return res, fig, fig2, fig3, fig4, fig5
# Event listeners must be registered inside a Blocks context, so re-enter
# `demo` here to attach the submit handler to the button.
with demo:
    btn.click(
        submit_inputs,
        # Order matches the positional parameters of submit_inputs.
        inputs=[
            age_input,
            bmi_input,
            hba1c_input,
            blood_glucose_input,
            existent_info_input,
            consent_input,
        ],
        # Order matches the 6-tuple returned by submit_inputs.
        outputs=[
            output,
            mean_plot,
            waterfall_plot,
            summary_plot,
            importance_plot,
            decision_plot,
        ],
    )

# Start the Gradio server.
demo.launch()