File size: 5,812 Bytes
9cb2f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f1b7f4
9cb2f0f
 
 
 
 
 
 
 
 
3f1b7f4
9cb2f0f
 
3f1b7f4
9cb2f0f
 
 
22cfb99
3f1b7f4
 
 
 
 
 
 
a464d14
3f1b7f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cb2f0f
3f1b7f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cb2f0f
 
 
 
 
 
4e4912d
3f1b7f4
9cb2f0f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Load the trained model from disk. The `with` block guarantees the file
# handle is closed once unpickling finishes (a bare `open(...)` leaks it).
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file; only deploy this app with a trusted, self-built heart_xgb.pkl.
with open("heart_xgb.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.

# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Predict heart-attack likelihood for one subject and explain it with SHAP.

    Each argument is one raw value from the Gradio UI: numbers from the
    Number/Slider widgets, and integer option indices from the Radio/Dropdown
    widgets (they are declared with ``type="index"``).

    Returns:
        A ``(probabilities, figure)`` tuple. The first item is a dict mapping
        ``"Low Chance"`` and ``"High Chance"`` to floats that sum to 1, for
        ``gr.Label``. The second is the matplotlib figure holding the SHAP
        bar plot, for ``gr.Plot``.

    Raises:
        gr.Error: if any input is empty (e.g. a radio button left unselected),
            so the user gets a readable message instead of a model dtype error.
    """
    # Same keys and order as before, so the columns line up with the feature
    # names/order the model expects (presumably its training order).
    features = {
        'age': age, 'sex': sex, 'cp': cp, 'trtbps': trtbps, 'chol': chol,
        'fbs': fbs, 'restecg': restecg, 'thalachh': thalachh, 'exng': exng,
        'oldpeak': oldpeak, 'slp': slp, 'caa': caa, 'thall': thall,
    }

    # Unselected index-typed widgets arrive as None; without this check the
    # frame becomes object-dtyped and prediction fails with an opaque error.
    missing = [name for name, value in features.items() if value is None]
    if missing:
        raise gr.Error(f"Please fill in all inputs (missing: {', '.join(missing)}).")

    # A single row with one float column per feature.
    new_row = pd.DataFrame([features], dtype=float)

    prob = loaded_model.predict_proba(new_row)
    p_low = float(prob[0][0])

    shap_values = explainer(new_row)
    shap.plots.bar(shap_values[0], max_display=8, order=shap.Explanation.abs, show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Detach the figure from pyplot's global state so figures do not pile up
    # across requests; the returned Figure object itself stays usable.
    plt.close(local_plot)

    return {"Low Chance": p_low, "High Chance": 1 - p_low}, local_plot

# Create the UI
title = "**Heart Attack Predictor & Interpreter** 🪐"
description1 = """This app takes info from subjects and predicts their heart attack likelihood. Do not use these results for an actual medical diagnosis."""

description2 = """
To use the app, simply adjust the inputs and click the "Analyze" button. You can also click one of the examples below to see how it's done!
"""

# The browser-tab title is plain text, so drop the Markdown bold markers there.
with gr.Blocks(title=title.replace("**", "")) as demo:

    # Header: app title and usage notes.
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"# {title}")
            gr.Markdown("## How does it work?")
            gr.Markdown(description1)
            gr.Markdown("---")
            gr.Markdown(description2)

    gr.Markdown("---")

    with gr.Row():
        # Left column: one widget per model feature. Radio/Dropdown widgets use
        # type="index", so main_func receives the option's position (0, 1, ...).
        with gr.Column():
            gr.Markdown("## Edit the Inputs Below:")
            gr.Markdown("---")

            with gr.Row():
                age = gr.Number(label="Age", info="How old are you?", value=40)
                # NOTE(review): this sends Male=0, Female=1 — confirm it matches
                # the sex encoding used in the model's training data.
                sex = gr.Radio(["Male", "Female"], label="Sex", info="What gender are you?", type="index")

            cp = gr.Radio(["Typical Angina", "Atypical Angina", "Non-anginal Pain", "Asymptomatic"], label="Chest Pain", info="What kind of chest pain do you have?", type="index")
            trtbps = gr.Number(label="trtbps", value=100)
            chol = gr.Number(label="chol", value=70)
            fbs = gr.Radio(["False", "True"], label="fbs", info="Is your fasting blood sugar > 120 mg/dl?", type="index")
            restecg = gr.Dropdown(["Normal", "Having ST-T wave abnormality", "Showing probable or definite left ventricular hypertrophy by Estes' criteria"], label="rest_ecg", type="index")
            # NOTE(review): thalach is the maximum heart rate achieved; a default
            # of 4 looks like a placeholder — confirm the intended default/range.
            thalachh = gr.Slider(label="thalach Score", minimum=1, maximum=205, value=4, step=1)
            exng = gr.Radio(["No", "Yes"], label="Exercise Induced Angina", type="index")
            oldpeak = gr.Slider(label="Oldpeak Score", minimum=1, maximum=10, value=4, step=1)
            slp = gr.Slider(label="Slp Score", minimum=1, maximum=5, value=4, step=1)
            caa = gr.Slider(label="Number of Major Vessels", minimum=1, maximum=3, value=3, step=1)
            thall = gr.Slider(label="Thall Score", minimum=1, maximum=5, value=4, step=1)

        # Right column: prediction label, SHAP plot, and clickable examples.
        with gr.Column():
            gr.Markdown("## Output:")
            gr.Markdown("---")
            with gr.Column(visible=True) as output_col:
                label = gr.Label(label="Predicted Label")
                local_plot = gr.Plot(label='Shap:')

            # Shared by the examples and the Analyze button; the order of
            # `inputs` must match main_func's parameter order.
            inputs = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]
            outputs = [label, local_plot]

            gr.Markdown("## Examples:")
            gr.Markdown("---")
            gr.Markdown("### Click on any of the examples below to see how it works:")
            gr.Examples(
                examples=[
                    [24, "Male", "Typical Angina", 4, 5, "True", "Normal", 4, "No", 5, 1, 2, 3],
                    [24, "Female", "Asymptomatic", 4, 5, "False", "Normal", 2, "Yes", 1, 1, 2, 3],
                ],
                inputs=inputs,
                outputs=outputs,
                fn=main_func,
                cache_examples=True,
            )

    submit_btn = gr.Button("Analyze", variant="primary")

    gr.Markdown("---")
    gr.Markdown("## Data Dictionary:")
    # Written as a Markdown list: plain consecutive lines would be merged by
    # the Markdown renderer into a single run-on paragraph.
    gr.Markdown("""
- **Age**: age of the patient
- **Sex**: sex of the patient
- **trtbps**: resting blood pressure (in mm Hg)
- **chol**: cholesterol in mg/dl fetched via BMI sensor
- **fbs**: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
- **rest_ecg**: resting electrocardiographic results
    - Value 0: normal
    - Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
    - Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria
- **thalach**: maximum heart rate achieved
- **target**: 0 = less chance of heart attack; 1 = more chance of heart attack
""")

    submit_btn.click(
        fn=main_func,
        inputs=inputs,
        outputs=outputs,
        api_name="Heart_Predictor",
    )


if __name__ == "__main__":
    demo.launch()