saad177 commited on
Commit
5d4b558
1 Parent(s): c734fd3

try waterfall plot explainability

Browse files
Files changed (2) hide show
  1. app.py +11 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import joblib
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import numpy as np
 
7
 
8
  project = hopsworks.login(project="SonyaStern_Lab1")
9
  fs = project.get_feature_store()
@@ -52,6 +53,9 @@ with gr.Blocks() as demo:
52
  with gr.Column():
53
  output = gr.Text(label="Model prediction")
54
  plot = gr.Plot()
 
 
 
55
 
56
  def submit_inputs(
57
  age_input,
@@ -105,6 +109,11 @@ with gr.Blocks() as demo:
105
  ax.set_xticks(indices + bar_width / 2)
106
  ax.set_xticklabels(categories)
107
 
 
 
 
 
 
108
  ## save user's data in hopsworks
109
  if consent_input == True:
110
  user_data_fg = fs.get_or_create_feature_group(
@@ -117,7 +126,7 @@ with gr.Blocks() as demo:
117
  user_data_df["diabetes"] = existent_info_input
118
  user_data_fg.insert(user_data_df)
119
  print("inserted new user data to hopsworks", user_data_df)
120
- return res, fig
121
 
122
  btn.click(
123
  submit_inputs,
@@ -129,7 +138,7 @@ with gr.Blocks() as demo:
129
  existent_info_input,
130
  consent_input,
131
  ],
132
- outputs=[output, plot],
133
  )
134
 
135
  demo.launch()
 
4
  import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
+ import shap
8
 
9
  project = hopsworks.login(project="SonyaStern_Lab1")
10
  fs = project.get_feature_store()
 
53
  with gr.Column():
54
  output = gr.Text(label="Model prediction")
55
  plot = gr.Plot()
56
+ with gr.Row():
57
+ with gr.Accordion("See model explanability"):
58
+ waterfall_plot = gr.Plot()
59
 
60
  def submit_inputs(
61
  age_input,
 
109
  ax.set_xticks(indices + bar_width / 2)
110
  ax.set_xticklabels(categories)
111
 
112
+ ## explainability plots
113
+ explainer = shap.Explainer(model)
114
+ shap_values = explainer(df)
115
+ shap_waterfall_plot = shap.plots.waterfall(shap_values[0])
116
+
117
  ## save user's data in hopsworks
118
  if consent_input == True:
119
  user_data_fg = fs.get_or_create_feature_group(
 
126
  user_data_df["diabetes"] = existent_info_input
127
  user_data_fg.insert(user_data_df)
128
  print("inserted new user data to hopsworks", user_data_df)
129
+ return res, fig, shap_waterfall_plot
130
 
131
  btn.click(
132
  submit_inputs,
 
138
  existent_info_input,
139
  consent_input,
140
  ],
141
+ outputs=[output, plot, waterfall_plot],
142
  )
143
 
144
  demo.launch()
requirements.txt CHANGED
@@ -4,4 +4,5 @@ joblib
4
  pandas
5
  scikit-learn==1.1.1
6
  matplotlib
7
- numpy
 
 
4
  pandas
5
  scikit-learn==1.1.1
6
  matplotlib
7
+ numpy
8
+ shap