import gradio as gr
import joblib
import numpy as np
import pandas as pd
import shap
from matplotlib import pyplot as plt
from utils.data_processor import DataProcessor
from utils.model_predictor import ModelPredictor
from utils.data import patients_data, key_to_display_name_and_value_conversion
from huggingface_hub import hf_hub_download
import joblib
# --- Model and preprocessing artifacts --------------------------------------
# NOTE(security): joblib.load unpickles its input, which can execute arbitrary
# code. The pipeline below is downloaded from the Hugging Face Hub, so only
# point this at a repository you trust (pinning `revision=` in
# hf_hub_download would also stop the artifact from changing underneath the app).
model_file = hf_hub_download(
    "Proddis/pbc_complication_model",
    "RandomForestClassifier_trained_pipeline.joblib",
)
model = joblib.load(model_file)

# Local artifacts shipped alongside the app.
categorical_names = joblib.load("resources/categorical_names.pkl")
target_labels = joblib.load("resources/target_labels.pkl")

# Handed to DataProcessor as-is; an empty list presumably means "no feature
# subset" -- TODO confirm against utils.data_processor.
selected_features = []

# The SHAP explainer is built on the pipeline's forest step, not the pipeline.
forest_step = model.named_steps["RandomForestClassifier"]
shap_explainer = shap.TreeExplainer(forest_step)

data_processor = DataProcessor(model, categorical_names, selected_features)
predictor = ModelPredictor(model)

# Class index -> label shown in the UI. Assumes class 0 is "Transplant/Death"
# in the trained model -- TODO confirm against target_labels.
labels_map = dict(enumerate(("Transplant/Death", "Survive")))

# Relative path the SHAP waterfall image is written to on every prediction.
plot_path = "shap_waterfall_plot.png"
def _user_input_html(user_input):
    """Render the patient's inputs as an HTML list of display names and values."""
    items = []
    for key, value in user_input.items():
        # convert_value maps a raw (key, value) pair to (display name, display value).
        display_name, converted_value = data_processor.convert_value(key, value)
        items.append(f"<li>{display_name}: {converted_value}</li>")
    joined_items = "".join(items)
    return f"<h3>Patient Data</h3><ul>{joined_items}</ul>"


def _class_shap_values(shap_values, class_index=0, row=0):
    """Return one row's per-feature SHAP values for a single output class.

    Accepts both return formats of ``TreeExplainer.shap_values``: a list with
    one ``(n_samples, n_features)`` array per class (older SHAP releases) and a
    single ``(n_samples, n_features, n_classes)`` array (newer releases).
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_index])[row, :]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[row, :, class_index]
    # Single-output model: there is no class axis to select from.
    return values[row, :]


def _save_waterfall_plot(explanation, max_display):
    """Draw the SHAP waterfall plot for ``explanation``, save it to ``plot_path`` and return that path."""
    shap.waterfall_plot(explanation, max_display=max_display, show=False)
    fig = plt.gcf()
    fig.set_size_inches(12, 8)
    try:
        fig.savefig(plot_path, bbox_inches="tight")
    finally:
        # Always release the figure so pyplot does not accumulate one per request.
        plt.close(fig)
    return plot_path


def _prediction_html(label, probabilities):
    """Render the predicted label and a per-class probability table as HTML."""
    # Format as percentages directly (avoids the deprecated DataFrame.applymap).
    rows = [[f"{p * 100:.1f}%" for p in row] for row in np.atleast_2d(probabilities)]
    proba_df = pd.DataFrame(rows, columns=list(labels_map.values()))
    proba_html = proba_df.to_html(classes="table table-striped", header=True, index=False)
    return f"<h3>Prediction: {label}</h3><p>Probabilities:</p>{proba_html}"


def select_and_predict(patient_selection):
    """Predict the outcome for a predefined patient profile and explain it.

    Args:
        patient_selection: Key into ``patients_data`` chosen in the dropdown.

    Returns:
        A tuple ``(user_input_html, prediction_html, image_path)``. It holds
        the patient's inputs as HTML, the predicted label with a probability
        table as HTML, and the path of the saved SHAP waterfall plot.

    Raises:
        gr.Error: If no profile (or an unknown one) is selected, so Gradio
            shows a readable message instead of a bare KeyError.
    """
    if patient_selection not in patients_data:
        raise gr.Error("Please select a patient profile.")
    user_input = patients_data[patient_selection]
    user_input_df = pd.DataFrame([user_input])  # one-row frame for the pipeline

    _, probabilities = predictor.predict(user_input_df)
    label = labels_map.get(int(np.argmax(probabilities)))

    # The SHAP explainer needs the custom-preprocessed input. The plot shows
    # the human-readable (mapped) values and display names instead.
    preprocessed_input = data_processor.shap_and_eli5_custom_format(user_input_df)
    mapped_row = data_processor.apply_mapping_to_row(user_input_df.iloc[0])
    features = [
        key_to_display_name_and_value_conversion.get(
            key, (key.replace("_", " ").title(), None)
        )[0]
        for key in user_input
    ]

    # Explain class 0 ("Transplant/Death" in labels_map). This matches the UI
    # caption, where right-pointing bars push towards 'Transplant/Death'.
    shap_values = shap_explainer.shap_values(preprocessed_input)
    explanation = shap.Explanation(
        values=_class_shap_values(shap_values, class_index=0),
        base_values=np.ravel(shap_explainer.expected_value)[0],
        data=mapped_row,
        feature_names=features,
    )
    image_path = _save_waterfall_plot(explanation, max_display=len(user_input_df.columns))

    return _user_input_html(user_input), _prediction_html(label, probabilities), image_path
with gr.Blocks() as app:
    # Page title.
    with gr.Row():
        gr.Markdown("# Risk of Disease Complication in Biliary Cirrhosis Patients")

    # Patient selector and Predict button. The two empty columns that follow
    # act purely as spacers, so the selector takes only part of the row.
    with gr.Row():
        with gr.Column(scale=2):
            dropdown = gr.Dropdown(list(patients_data.keys()), label="Select Patient Profile")
            btn = gr.Button("Predict")
        # gr.Column options are keyword-only, so no positional arguments here.
        gr.Column(scale=2)
        gr.Column(scale=2)

    # Outputs: patient inputs, prediction table, and the SHAP waterfall plot.
    with gr.Row():
        with gr.Column(scale=1):
            user_input_output = gr.HTML()
        with gr.Column(scale=1):
            prediction_html = gr.HTML()
        with gr.Column(scale=2):
            gr.Markdown("# Risk Factors of Disease Complication")
            output_image = gr.Image(show_share_button=False)
            gr.Markdown(
                "Left arrows show which features tip the scales towards 'Survive', "
                "and right arrows show which features lean towards 'Transplant/Death'."
            )

    btn.click(
        fn=select_and_predict,
        inputs=dropdown,
        outputs=[user_input_output, prediction_html, output_image],
    )

app.launch()