import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# load the model from disk
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# AIDS.pkl must be a trusted artifact shipped with this app, never user input.
with open("AIDS.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model)  # PLEASE DO NOT CHANGE THIS.


def _shap_bar_figure(row):
    """Return a matplotlib figure with a SHAP bar plot for the single-row frame *row*."""
    shap_values = explainer(row)
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    figure = plt.gcf()
    # Drop the figure from pyplot's global registry so figures do not pile up
    # across requests; the returned Figure object is still handed to Gradio.
    plt.close()
    return figure


# Create the main function for server
def main_func(time, trt, age, wtkg, hemo, homo, drugs, karnof, oprior, z30,
              zprior, preanti, race, gender, str2, strat, symptom, treat,
              offtrt, cd40, cd420, cd80, cd820):
    """Predict for one subject and explain the prediction with SHAP.

    Each argument is one model feature, exactly as delivered by the matching
    Gradio component: a slider value, or (for radios/dropdowns declared with
    ``type="index"``) the index of the selected option (presumably ``None``
    when nothing is selected -- confirm against the Gradio version in use).

    Returns:
        A tuple ``(probabilities, figure)``. ``probabilities`` maps
        "Low Chance" to the model's probability for its first class and
        "High Chance" to the complement, so the two sum to 1. ``figure`` is a
        matplotlib SHAP bar chart (capped at six bars) for this subject.
    """
    # Build a one-column frame (feature name -> value) and transpose it into
    # the single row the model expects. Going through one column means every
    # feature ends up sharing a single dtype. Dict order fixes column order.
    new_row = pd.DataFrame.from_dict({'time': time, 'trt': trt, 'age': age, 'wtkg': wtkg,
                                      'hemo': hemo, 'homo': homo, 'drugs': drugs,
                                      'karnof': karnof, 'oprior': oprior, 'z30': z30,
                                      'zprior': zprior, 'preanti': preanti, 'race': race,
                                      'gender': gender, 'str2': str2, 'strat': strat,
                                      'symptom': symptom, 'treat': treat, 'offtrt': offtrt,
                                      'cd40': cd40, 'cd420': cd420, 'cd80': cd80,
                                      'cd820': cd820},
                                     orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    low_chance = float(prob[0][0])

    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, _shap_bar_figure(new_row)


# Create the UI
title = "**AIDS Predictor and Interpreter** 🧪"
description1 = """This app takes information from subjects and predicts their likelihood of contracting AIDS. Do not use for medical diagnosis."""
description2 = """ To use the app, click on one of the examples, or adjust the values of the factors, and click on "Analyze". 
🤞 """

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")
    with gr.Row():
        time = gr.Slider(label="Time to failure or censoring", minimum=14, maximum=1231, value=600, step=100)
    with gr.Row():
        trt = gr.Radio(["ZDV Only", "ZDV + ddl", "ZDV + Zal", "ddl only"], label="Treatment", type="index")
    with gr.Row():
        age = gr.Slider(label="Age", minimum=12, maximum=70, value=40, step=1)
    with gr.Row():
        wtkg = gr.Slider(label="Weight in kg", minimum=30, maximum=160, value=70, step=1)
    with gr.Row():
        hemo = gr.Radio(["No", "Yes"], label="History of Hemophilia", type="index")
        homo = gr.Radio(["No", "Yes"], label="Homosexual Activity", type="index")
        drugs = gr.Radio(["No", "Yes"], label="Intravenous drug use", type="index")
    with gr.Row():
        karnof = gr.Slider(label="Karnof Score (0-100)", minimum=0, maximum=100, value=90, step=5)
    with gr.Row():
        oprior = gr.Radio(["No", "Yes"], label="Non-ZDV anti-retroviral therapy pre-175", type="index")
        z30 = gr.Radio(["No", "Yes"], label="ZDV in the 30 days prior to 175", type="index")
        zprior = gr.Radio(["No", "Yes"], label="ZDV prior to 175", type="index")
    with gr.Row():
        preanti = gr.Slider(label="Number of days pre-175 anti-retroviral therapy", minimum=0, maximum=2851, value=1000, step=1)
    with gr.Row():
        race = gr.Radio(["White", "Non-White"], label="Race", type="index")
        gender = gr.Radio(["Female", "Male"], label="Gender", type="index")
        str2 = gr.Radio(["Naive", "Experienced"], label="Anti-Retroviral History", type="index")
    with gr.Row():
        strat = gr.Radio(["Anti-retroviral naive", ">1 but <=52 weeks of prior anti-retroviral therapy", ">52 weeks"],
                         label="Anti-retroviral history stratification", type="index")
    with gr.Row():
        symptom = gr.Dropdown(label="Symptoms", choices=["Asymptomatic", "Symptomatic"], type="index")
        treat = gr.Dropdown(label="Treatment", choices=["ZDV only", "others"], type="index")
        offtrt = gr.Dropdown(label="Off treatment indicator before 96 +/- 5 weeks", choices=["No", "Yes"], type="index")
    with gr.Row():
        cd40 = gr.Slider(label="CD4 at baseline", minimum=0, maximum=1119, value=500, step=50)
        cd420 = gr.Slider(label="CD4 at 20 +/-5 weeks", minimum=49, maximum=1119, value=600, step=50)
        cd80 = gr.Slider(label="CD8 at baseline", minimum=40, maximum=5011, value=2000, step=250)
        cd820 = gr.Slider(label="CD8 at 20 +/-5 weeks", minimum=124, maximum=6035, value=3000, step=250)

    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Single source of truth for the component order; it must match
    # main_func's parameter order and is shared by the button and examples.
    input_components = [time, trt, age, wtkg, hemo, homo, drugs, karnof, oprior, z30,
                        zprior, preanti, race, gender, str2, strat, symptom, treat,
                        offtrt, cd40, cd420, cd80, cd820]
    output_components = [label, local_plot]

    submit_btn.click(
        main_func,
        input_components,
        output_components,
        api_name="AIDS_Predictor"
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        [[550, "ZDV Only", 20, 90, "Yes", "No", "Yes", 30, "No", "No", "Yes", 1000,
          "White", "Male", "Experienced", ">52 weeks", "Symptomatic", "ZDV only", "Yes",
          500, 450, 100, 250],
         [875, "ZDV + ddl", 45, 70, "No", "Yes", "Yes", 50, "Yes", "Yes", "No", 1650,
          "Non-White", "Female", "Naive", "Anti-retroviral naive", "Asymptomatic", "others", "No",
          1000, 250, 4000, 1200]],
        input_components,
        output_components,
        main_func,
        cache_examples=True,
    )

demo.launch()