import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Gradio theme for the app.
# NOTE(review): the original line chained an unclosed `.set(` whose style
# overrides are not visible here (it was a syntax error as written); re-add any
# known overrides as keyword arguments to `.set(...)`.
theme = gr.themes.Default(primary_hue="blue")

# Load the trained classifier from disk; `with` closes the file handle right
# after unpickling. pickle can run arbitrary code while loading, so this path
# must only ever point at a model file you produced and trust.
with open("heart_xgbV2.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model)  # PLEASE DO NOT CHANGE THIS.

# Lookup tables translating the human-readable answers shown in the UI into
# the integer codes fed to the model.
gender_dict = {"Male": 0, "Female": 1}
cp_dict = {"Typical Angina": 0, "Atypical Angina": 1, "Non-Anginal": 2, "Asymptomatic": 3}
fbs_dict = {"Yes": 1, "No": 0}
exng_dict = {"Yes": 1, "No": 0}
restecg_dict = {"Normal": 0, "Having ST-T abnormality": 1, "Showing probable or definite left ventricular hypertrophy by Estes' Criteria": 2}
thall_dict = {"Fixed Defect": 1, "Normal Blood Flow": 2, "Reversible Defect": 3}
slp_dict = {"Upsloping": 1, "Flat": 2, "Downsloping": 3}
# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Predict heart-attack likelihood for one subject and explain it with SHAP.

    Categorical answers (``sex``, ``cp``, ``fbs``, ``restecg``, ``exng``,
    ``slp``, ``thall``) arrive as the display strings offered by the UI and are
    translated to model codes via the module-level ``*_dict`` lookups; the
    numeric answers (``age``, ``trtbps``, ``chol``, ``thalachh``, ``oldpeak``,
    ``caa``) are passed through unchanged.

    Returns:
        A 2-tuple of
        - a dict mapping the two outcome labels to probabilities (for ``gr.Label``),
        - a matplotlib ``Figure`` holding a SHAP bar plot of the top features
          (for ``gr.Plot``).

    Raises:
        gr.Error: if any answer was left unset, so the user sees which fields
            still need filling instead of a bare lookup error.
    """
    answers = {
        "age": age, "sex": sex, "cp": cp, "trtbps": trtbps, "chol": chol,
        "fbs": fbs, "restecg": restecg, "thalachh": thalachh, "exng": exng,
        "oldpeak": oldpeak, "slp": slp, "caa": caa, "thall": thall,
    }
    missing = [name for name, value in answers.items() if value is None]
    if missing:
        raise gr.Error("Please fill in all the options. Missing: " + ", ".join(missing))

    # One-row feature frame. Column order follows this function's parameter
    # order. NOTE(review): assumed to match the order/names the model was
    # trained on -- confirm against loaded_model's feature names.
    features = {
        "age": age,
        "sex": gender_dict[sex],
        "cp": cp_dict[cp],
        "trtbps": trtbps,
        "chol": chol,
        "fbs": fbs_dict[fbs],
        "restecg": restecg_dict[restecg],
        "thalachh": thalachh,
        "exng": exng_dict[exng],
        "oldpeak": oldpeak,
        "slp": slp_dict[slp],
        "caa": caa,
        "thall": thall_dict[thall],
    }
    new_row = pd.DataFrame.from_dict(features, orient="index").transpose()

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # shap.plots.bar draws onto plt.gca(); start a fresh figure so bars from a
    # previous request (or a cached example) are not stacked onto this one.
    local_plot = plt.figure()
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data="auto", show=False)
    # Detach the figure from pyplot's registry so figures don't accumulate
    # across requests; the Figure object itself stays renderable by gr.Plot.
    plt.close(local_plot)

    lower = float(prob[0][0])
    return {"Lower Chance of a Heart Attack": lower, "Higher Chance of a Heart Attack": 1 - lower}, local_plot
# Create the UI
# Header/description text for the UI. The Blocks layout repeats this text
# inline as gr.Markdown strings; keep the two in sync when editing either.
title = "**Heart Attack Predictor & Interpreter** 🪐"
description1 = "This app takes info from subjects and predicts their heart attack likelihood."
description_notmedical = "**Do not use for medical diagnosis.**"
# Bold markers must be balanced or Markdown renders stray asterisks.
description2 = "**Fill all the options or no result will be generated!!!**"
description3 = "To use the app, please fill in all the options, and click on Analyze. 🤞"
descriptionExamples = "If you would like to see how the model works, please scroll down and try one of the examples!"
# Build the Gradio interface: header text, one input per model feature,
# an Analyze button wired to main_func, the outputs, and cached examples.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(" **Heart Attack Predictor & Interpreter** 🪐")
    gr.Markdown(" **Do not use for medical diagnosis.**")
    gr.Markdown(" If you would like to see how the model works, please scroll down and try one of the examples!")
    gr.Markdown(" This app takes info from subjects and predicts their heart attack likelihood.")
    gr.Markdown(" To use the app, please fill in all the options, and click on Analyze. 🤞")
    gr.Markdown(" **Fill all the options or no result will be generated!!!**")

    with gr.Row():
        with gr.Column():
            age = gr.Number(label="What is your age?", value=40)
        with gr.Column():
            slp = gr.Dropdown(["Upsloping", "Flat", "Downsloping"], label="What was the slope of the peak exercise ST segment?")
    with gr.Row():
        with gr.Column():
            sex = gr.Radio(["Female", "Male"], label="What is your sex?")
            cp = gr.Radio(["Typical Angina", "Atypical Angina", "Non-Anginal", "Asymptomatic"], label="What kind of chest pain is it?")
        with gr.Column():
            restecg = gr.Radio(["Normal", "Having ST-T abnormality", "Showing probable or definite left ventricular hypertrophy by Estes' Criteria"],
                               label="What is your resting ECG result?")
    with gr.Row():
        with gr.Column():
            fbs = gr.Radio(["Yes", "No"], label="Is your fasting Blood Sugar >120 mg/dl?")
        with gr.Column():
            exng = gr.Radio(["Yes", "No"], label="Do you have Exercise Induced Angina?")
    with gr.Row():
        with gr.Column():
            # Zero coloured vessels is a valid answer, so 0 is offered too.
            caa = gr.Radio([0, 1, 2, 3], label="How many vessels were colored by the fluoroscopy?")
        with gr.Column():
            thall = gr.Radio(["Fixed Defect", "Normal Blood Flow", "Reversible Defect"], label="What is your Thalassemia condition?")
    with gr.Row():
        with gr.Column():
            trtbps = gr.Slider(label="What is your resting blood Pressure (in mm Hg)?", minimum=10, maximum=250, value=100, step=1)
        with gr.Column():
            chol = gr.Slider(label="What is your cholesterol in mg/dl (via BMI sensor)?", minimum=30, maximum=300, value=180, step=1)
    with gr.Row():
        with gr.Column():
            oldpeak = gr.Slider(label="What was the ST depression induced by exercise relative to rest?", minimum=0, maximum=6.2, step=0.1)
        with gr.Column():
            thalachh = gr.Slider(label="What is your maximum heart rate?", minimum=60, maximum=250, value=100, step=1)
    with gr.Row():
        submit_btn = gr.Button("Analyze")

    ## Do not need to touch
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Inputs in exactly main_func's parameter order; shared by the button and
    # the examples so the two can never drift apart.
    input_components = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]

    submit_btn.click(
        fn=main_func,
        inputs=input_components,
        outputs=[label, local_plot],
        api_name="Heart_Predictor",
    )

    gr.Examples(
        examples=[
            [24, "Male", "Typical Angina", 130, 150, "Yes", "Having ST-T abnormality", 170, "Yes", 5.1, "Flat", 2, "Normal Blood Flow"],
            [59, "Female", "Non-Anginal", 150, 170, "No", "Showing probable or definite left ventricular hypertrophy by Estes' Criteria", 190, "No", 6, "Upsloping", 3, "Reversible Defect"],
        ],
        inputs=input_components,
        outputs=[label, local_plot],
        fn=main_func,
        cache_examples=True,
    )