# Group_5 / app.py
# Repository page header captured together with the file:
#   uploader: avc3px
#   commit:   216022c (verified) - "Update app.py"
#   size:     7.64 kB
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# load the model from disk
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source. The context manager closes the file handle once the model
# has been read instead of leaving it open for the life of the process.
with open("cdc_diabetes_health_indicators.pkl", 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.

# Label -> integer code tables for the three ordinal questions.
# In every table the codes start at 1 and follow the order of the labels.
age_d = {
    "18-24": 1,
    "25-29": 2,
    "30-34": 3,
    "35-39": 4,
    "40-44": 5,
    "45-49": 6,
    "50-54": 7,
    "55-59": 8,
    "60-64": 9,
    "65-69": 10,
    "70-74": 11,
    "75-79": 12,
    "80 and older": 13,
}
education_d = {
    "Never attended school or only kindergarten": 1,
    "Grades 1 through 8 (Elementary)": 2,
    "Grades 9 through 11 (Some high school)": 3,
    "Grade 12 or GED (High school graduate)": 4,
    "College 1 year to 3 years (Some college or technical school)": 5,
    "College 4 years or more (College graduated)": 6,
}
income_d = {
    "Less than $10,000": 1,
    "Less than $16,250": 2,
    "Less than $22,500": 3,
    "Less than $28,750": 4,
    "Less than $35,000": 5,
    "Less than $48,500": 6,
    "Less than $61,500": 7,
    "$75,000 or more": 8,
}
# Create the main function for server
def _index_to_code(index):
    """Convert a 0-based dropdown index into the matching 1-based category code.

    ``None`` (nothing selected in the dropdown) is passed through unchanged so
    that a missing answer reaches the model exactly as it did before.
    """
    return None if index is None else index + 1


def main_func(HighBP, HighChol, CholCheck, BMI, Smoker, Stroke, HeartDiseaseorAttack, PhysActivity, Fruits, Veggies, HvyAlcoholConsump, AnyHealthcare, NoDocbcCost, GenHlth, MentHlth, PhysHlth, DiffWalk, Sex, Age, Education, Income):
    """Predict the diabetes likelihood for one respondent and explain it.

    The arguments are the values delivered by the UI components of the same
    name. Radio and dropdown components are declared with ``type="index"``,
    so their values arrive as 0-based positions of the selected choice;
    sliders arrive as numbers.

    Returns:
        A 2-tuple ``(scores, figure)``:
        * ``scores`` - dict with the keys ``"Low Chance"`` (probability in the
          first column of ``predict_proba``) and ``"High Chance"`` (its
          complement).
        * ``figure`` - the matplotlib figure holding a SHAP bar plot of the
          six largest contributions for this row.
    """
    # Age, Education and Income arrive as 0-based dropdown indices, whereas
    # the code tables in this file (age_d, education_d, income_d) number the
    # same labels starting at 1. Shift the indices so the model receives the
    # codes from those tables.
    # NOTE(review): assumes the model was trained on the 1-based codes -
    # confirm against the training data.
    age_code = _index_to_code(Age)
    education_code = _index_to_code(Education)
    income_code = _index_to_code(Income)

    # Build a single-row frame whose column order is the insertion order below.
    new_row = pd.DataFrame.from_dict({
        'HighBP': HighBP, 'HighChol': HighChol, 'CholCheck': CholCheck,
        'BMI': BMI, 'Smoker': Smoker, 'Stroke': Stroke, 'HeartDiseaseorAttack': HeartDiseaseorAttack,
        'PhysActivity': PhysActivity, 'Fruits': Fruits, 'Veggies': Veggies, 'HvyAlcoholConsump': HvyAlcoholConsump,
        'AnyHealthcare': AnyHealthcare, 'NoDocbcCost': NoDocbcCost, 'GenHlth': GenHlth, 'MentHlth': MentHlth,
        'PhysHlth': PhysHlth, 'DiffWalk': DiffWalk, 'Sex': Sex, 'Age': age_code, 'Education': education_code, 'Income': income_code},
        orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # show=False keeps the plot on the current figure so it can be returned
    # instead of being displayed.
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Deregister the figure from pyplot; the object itself is still returned.
    plt.close(local_plot)

    low_chance = float(prob[0][0])
    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot
# Create the UI
title = "**Diabetes Predictor & Interpreter** "
# The browser tab does not render Markdown, so give it the heading text
# without the surrounding asterisks and trailing space.
page_title = title.strip().strip("*")
description1 = """This app takes information from participants and predicts their diabetes likelihood. Do not use for medical diagnosis."""
description2 = """
To use the app, pick the most applicable option for you by ticking on the circles, adjusting the values, and clicking on the boxes to show the options. After completion, click on Analyze. There are two examples below to see how it works.
"""


def _yes_no_radio(label):
    """Build a No/Yes radio that reports the position of the chosen option."""
    return gr.Radio(["No", "Yes"], label=label, type="index")


with gr.Blocks(title=page_title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    # Input components, created in the order they appear on the page.
    HighBP = _yes_no_radio("Do you have high blood pressure?")
    HighChol = _yes_no_radio("Do you have high cholesterol?")
    CholCheck = _yes_no_radio("Did you have your cholesterol check within 5 years?")
    BMI = gr.Slider(label="BMI", minimum=12, maximum=98, value=12, step=1)
    Smoker = _yes_no_radio("Have you smoked at least 100 cigarettes in your entire life? Note: 5 packs = 100 cigarettes")
    Stroke = _yes_no_radio("Did you ever had a stroke?")
    HeartDiseaseorAttack = _yes_no_radio("Do you have either a Coronary Heart Disease(CHD) or a Myocardial Infarction(heart attack)?")
    PhysActivity = _yes_no_radio("Do you do any physical activity in the past 30 days in (not including your job)?")
    Fruits = _yes_no_radio("Do you eat fruits once or more times per day?")
    Veggies = _yes_no_radio("Do you eat vegetables once or more times per day?")
    HvyAlcoholConsump = _yes_no_radio("Are you a heavy drinker? Note: Adult men = more than 14 drinks per week & Adult Women = more than 7 drinks per week")
    AnyHealthcare = _yes_no_radio("Do you have any kind of healthcare coverage, including health insurance, prepaid plans such as HMO, etc.?")
    NoDocbcCost = _yes_no_radio("Was there a time in the past 12 months when you needed to see a doctor but could not because of cost?")
    GenHlth = gr.Slider(label="How would rate your general health? Note: 1 = excellent, 2 = very good, 3 = good, 4 = fair, 5 = poor", minimum=1, maximum=5, value=1, step=1)
    MentHlth = gr.Slider(label="How many days was your mental health not good in the past 30 days? This includes stress, depression, and problems with emotions.", minimum=0, maximum=30, value=0, step=1)
    PhysHlth = gr.Slider(label="How many days was your physical health not good in the past 30 days? This includes physical illness and injuries.", minimum=0, maximum=30, value=0, step=1)
    DiffWalk = _yes_no_radio("Do you have serious difficulty walking or climbing stairs?")
    # NOTE(review): with type="index" "Male" is reported as 0 and "Female" as 1 -
    # confirm that this matches the Sex encoding the model expects.
    Sex = gr.Radio(["Male", "Female"], label="Sex", type="index")
    # The choices are the labels of the code tables defined earlier in this
    # file, in the same order, so the two cannot drift apart.
    Age = gr.Dropdown(list(age_d), label="Age (in years)", type="index")
    Education = gr.Dropdown(list(education_d), label="Education Level", type="index")
    Income = gr.Dropdown(list(income_d), label="Income Level", type="index")

    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Shared by the Analyze button and the examples; the order must match the
    # parameter order of main_func.
    input_components = [
        HighBP, HighChol, CholCheck, BMI, Smoker, Stroke, HeartDiseaseorAttack,
        PhysActivity, Fruits, Veggies, HvyAlcoholConsump, AnyHealthcare,
        NoDocbcCost, GenHlth, MentHlth, PhysHlth, DiffWalk, Sex, Age,
        Education, Income,
    ]
    output_components = [label, local_plot]

    submit_btn.click(
        main_func,
        input_components,
        output_components,
        api_name="Diabetes_Predictor",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    example_rows = [
        ["No", "No", "No", 23, "No", "No", "No", "Yes", "Yes", "Yes", "No", "Yes", "No", 1, 2, 4, "No", "Female", "65-69", "College 4 years or more (College graduated)", "$75,000 or more"],
        ["Yes", "Yes", "Yes", 32, "Yes", "Yes", "Yes", "No", "No", "No", "Yes", "No", "Yes", 5, 15, 20, "Yes", "Male", "50-54", "Grade 12 or GED (High school graduate)", "Less than $35,000"],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=input_components,
        outputs=output_components,
        fn=main_func,
        cache_examples=True,
    )

demo.launch()