File size: 7,650 Bytes
86508a7 39bf2c4 86508a7 39bf2c4 86508a7 04010a4 2e2da6d 49a82a3 91d6bc4 685360f 73c527b 2f7f132 73c527b 685360f f4f2263 73c527b 879947a 73c527b cd35bd2 73c527b 2f7f132 73c527b f30b579 73c527b 2f7f132 73c527b 2f7f132 73c527b 9afac68 73c527b 9afac68 04010a4 55f0fa6 04010a4 55f0fa6 04010a4 86508a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# load the model from disk
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so db_xgb.pkl must only ever come from a trusted source.
with open("db_xgb.pkl", "rb") as model_file:  # closes the handle after loading
    loaded_model = pickle.load(model_file)
# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
# Create the main function for server
def main_func(
    HighBP, HighChol, CholCheck, BMI, Smoker, Stroke, HeartDiseaseorAttack,
    PhysActivity, Fruits, Veggies, HvyAlcoholConsump, AnyHealthcare,
    NoDocbcCost, GenHlth, MentHlth, PhysHlth, DiffWalk, Sex, Age, Education,
    Income,
):
    """Score one subject with the loaded model and explain the score with SHAP.

    Every argument is a single feature value for the subject. The values are
    placed in a one-row DataFrame whose column names match the argument names
    and whose column order matches the argument order.

    Returns:
        A 2-tuple ``(scores, figure)`` where ``scores`` maps "Low Chance" to
        the first probability returned by ``predict_proba`` and "High Chance"
        to its complement, and ``figure`` is the matplotlib figure holding the
        SHAP bar plot (at most 6 features displayed).

    Raises:
        ValueError: if a feature value cannot be converted to a float
            (for example a text label instead of a numeric code).
    """
    features = {
        'HighBP': HighBP,
        'HighChol': HighChol,
        'CholCheck': CholCheck,
        'BMI': BMI,
        'Smoker': Smoker,
        'Stroke': Stroke,
        'HeartDiseaseorAttack': HeartDiseaseorAttack,
        'PhysActivity': PhysActivity,
        'Fruits': Fruits,
        'Veggies': Veggies,
        'HvyAlcoholConsump': HvyAlcoholConsump,
        'AnyHealthcare': AnyHealthcare,
        'NoDocbcCost': NoDocbcCost,
        'GenHlth': GenHlth,
        'MentHlth': MentHlth,
        'PhysHlth': PhysHlth,
        'DiffWalk': DiffWalk,
        'Sex': Sex,
        'Age': Age,
        'Education': Education,
        'Income': Income,
    }
    # Build the single row directly instead of from_dict(...).transpose(), and
    # cast to float up front: a mixed-type row would otherwise become an
    # object-dtype frame, and a non-numeric value now fails here with a clear
    # conversion error instead of deep inside the model.
    new_row = pd.DataFrame([features]).astype(float)

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # Only the bar plot is returned. The force and waterfall plots that used
    # to be drawn first were assigned to a variable that was never read.
    # NOTE(review): shap's matplotlib plots appear to draw into the current
    # pyplot figure, so those extra plots would also have ended up layered
    # under the bar plot - confirm against the installed shap version.
    plt.figure()  # start from a clean figure for this request
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Deregister the figure from pyplot; the Figure object itself is still
    # handed back to the caller.
    plt.close()

    low_chance = float(prob[0][0])
    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot
# Create the UI
# Static copy shown at the top of the page: the heading text and the two
# Markdown paragraphs rendered above the input form.
title = "**Diabetes Predictor & Interpreter** πͺ"
description1 = (
    "This app takes info from subjects and predicts their diabetes likelihood. "
    "Do not use for medical diagnosis."
)
# The leading and trailing newlines are part of the string value.
description2 = (
    "\n"
    "To use the app, click on one of the examples, or adjust the values of the "
    "factors, and click on Analyze. π€"
    "\n"
)
# Choice lists pair the label shown to the user with the numeric code handed
# to main_func, so the form sends the same kind of values as the example rows.
_YES_NO = [("No", 0), ("Yes", 1)]
# NOTE(review): Female=0 / Male=1 is assumed from the choice order - confirm
# against the data the model was trained on.
_SEX = [("Female", 0), ("Male", 1)]
# NOTE(review): Education and Income codes are assumed to be the 1-based
# position of each label (the example rows use Education 5/3 and Income 3/2,
# which fit that scheme) - confirm against the training data.
_EDUCATION = [
    ("Never attended school", 1),
    ("Grades 1-8", 2),
    ("Grades 9-11", 3),
    ("Grade 12 or GED", 4),
    ("College 1-3 years", 5),
    ("College 4+ years", 6),
]
_INCOME = [
    ("< $10,000", 1),
    ("$10,000 - $24,999", 2),
    ("$25,000 - $49,999", 3),
    ("$50,000 - $74,999", 4),
    ("$75,000 or more", 5),
]
# One value per input component, in the order of the inputs list below.
# NOTE(review): the original rows carried 22 values for 21 inputs; the leading
# value was dropped because the remaining 21 then fall inside every input's
# range (BMI 22/30, GenHlth 3/2, Age 21, ...) - confirm this is the intended
# alignment.
_EXAMPLE_ROWS = [
    [0, 1, 0, 22, 0, 0, 0, 1, 1, 1, 0, 0, 1, 3, 25, 23, 1, 1, 21, 5, 3],
    [1, 1, 1, 30, 1, 1, 1, 0, 0, 0, 1, 1, 0, 2, 20, 23, 0, 0, 21, 3, 2],
]

# Components take their initial value via `value=` and their help text via
# `info=`; the previous `default=` / `description=` keywords are not component
# parameters.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")
    with gr.Row():
        with gr.Column():
            HighBP = gr.Radio(label="Do you have high blood pressure?", choices=_YES_NO, value=1, info="0 = no high BP, 1 = high BP")
        with gr.Column():
            HighChol = gr.Radio(label="Do you have high cholesterol?", choices=_YES_NO, value=1, info="0 = no high cholesterol, 1 = high cholesterol")
    with gr.Row():
        with gr.Column():
            CholCheck = gr.Radio(label="Did you check your cholestorol in the past 5 years?", choices=_YES_NO, value=1, info="No = not checked in 5 years, Yes = checked in 5 years")
        with gr.Column():
            BMI = gr.Number(label="BMI", minimum=0, maximum=98, value=1)
    with gr.Row():
        with gr.Column():
            Smoker = gr.Radio(label="Are you a smoker?", choices=_YES_NO, value=1, info="No = never smoked, Yes = smoked at least 100 cigarettes")
    with gr.Row():
        with gr.Column():
            Stroke = gr.Radio(label="Have you had a stroke?", choices=_YES_NO, value=1, info="No = never had a stroke, Yes = had a stroke")
        with gr.Column():
            HeartDiseaseorAttack = gr.Radio(label="Do you have coronary heart disease or myocardial infarction?", choices=_YES_NO, value=1, info="No = no CHD/MI, Yes = CHD/MI")
    with gr.Row():
        with gr.Column():
            PhysActivity = gr.Radio(label="Did you partake in physical activity in the past 30 days?", choices=_YES_NO, value=1, info="No = no activity, Yes = active")
    with gr.Row():
        with gr.Column():
            Fruits = gr.Radio(label="Do you consume fruit 1 or more times per day?", choices=_YES_NO, value=1, info="No = less than daily, Yes = daily")
        with gr.Column():
            Veggies = gr.Radio(label="Do you consume vegetables 1 or more times per day?", choices=_YES_NO, value=1, info="No = less than daily, Yes = daily")
        with gr.Column():
            HvyAlcoholConsump = gr.Radio(label="Do you drink often? (adult men having more than 14 drinks/week and adult women having more than 7 drinks/week)", choices=_YES_NO, value=1, info="No = not heavy drinker, Yes = heavy drinker")
    with gr.Row():
        with gr.Column():
            AnyHealthcare = gr.Radio(label="Do you have any kind of health care coverage? (e.g., health insurance, prepaid plans such as HMO)", choices=_YES_NO, value=1, info="No = no coverage, Yes = coverage")
        with gr.Column():
            NoDocbcCost = gr.Radio(label="Was there a time in the past 12 months when you needed to see a doctor but could not because of cost?", choices=_YES_NO, value=1, info="No = no barrier, Yes = cost barrier")
    with gr.Row():
        GenHlth = gr.Slider(label="In general, rank your health on a scale: 1(excellent)-5(poor)", minimum=1, maximum=5, value=1, step=1, info="1 = excellent, 5 = poor")
    with gr.Row():
        MentHlth = gr.Number(label="How many days in the past 30 days did you have poor mental health?", minimum=0, maximum=30, value=1, info="Days not good out of last 30")
    with gr.Row():
        PhysHlth = gr.Number(label="How many days in the past 30 days did you have poor physical health?", minimum=0, maximum=30, value=1, info="Days not good out of last 30")
    with gr.Row():
        DiffWalk = gr.Radio(label="Do you have serious difficulty walking or climbing stairs?", choices=_YES_NO, value=1, info="No = no difficulty, Yes = difficulty")
    with gr.Row():
        Sex = gr.Radio(label="Sex", choices=_SEX, value=1, info="Female or Male")
    with gr.Row():
        Age = gr.Number(label="Age", minimum=1, maximum=100, value=1)
    with gr.Row():
        Education = gr.Dropdown(label="Education Level", choices=_EDUCATION, value=1, info="Education level")
    with gr.Row():
        Income = gr.Dropdown(label="Income Level", choices=_INCOME, value=1, info="Income level")
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label = "Predicted Label")
        local_plot = gr.Plot(label = 'Shap:')
    submit_btn = gr.Button("Analyze")

    # The same input/output wiring is shared by the button and the examples;
    # the input order must match main_func's parameter order.
    model_inputs = [HighBP, HighChol, CholCheck, BMI, Smoker, Stroke, HeartDiseaseorAttack, PhysActivity, Fruits, Veggies, HvyAlcoholConsump, AnyHealthcare, NoDocbcCost, GenHlth, MentHlth, PhysHlth, DiffWalk, Sex, Age, Education, Income]
    model_outputs = [label, local_plot]

    submit_btn.click(
        main_func,
        model_inputs,
        model_outputs, api_name="Diabetes_Predictor"
    )
    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(_EXAMPLE_ROWS, model_inputs, model_outputs, main_func, cache_examples=True)
demo.launch()
|