"""Gradio app: 1-year mortality risk calculator after esophageal variceal bleeding.

At import time this module loads a fitted preprocessor and a classifier from
joblib artifacts, then builds a tabbed Gradio form. On "Predict Outcome", the
form values are assembled into a one-row DataFrame, transformed by the
preprocessor, and scored by the model.
"""

import gradio as gr
import joblib
import pandas as pd

PREPROCESSOR_PATH = "preprocessor_v1.0.joblib"
MODEL_PATH = "calibrated_random_forest_model_updated_v1.1.joblib"

# NOTE: joblib.load unpickles arbitrary Python objects and can execute code
# embedded in the file. Only point these paths at artifacts from a trusted
# source.
preprocessor = joblib.load(PREPROCESSOR_PATH)
model = joblib.load(MODEL_PATH)


def predict_patient_outcome(
    age: int,
    sex: str,
    race: str,
    etiology_cirrosis: str,
    hepatorenal_syndrome: str,
    omeprazole: str,
    spironolactone: str,
    furosemide: str,
    propanolol: str,
    dialisis: str,
    portal_vein_thrombosis: str,
    ascitis: str,
    hepatocellular_carcinoma: str,
    albumin: float,
    total_bilirrubin: float,
    direct_bilirrubina: float,
    inr: float,
    creatinine: float,
    platelets: float,
    ast: float,
    alt: float,
    hemoglobin: float,
    hematocrit: float,
    leucocytes: float,
    sodium: float,
    potassium: float,
    varices: str,
    red_wale_marks: str,
    rupture_point: str,
    active_bleeding: str,
    therapy: str,
    terlipressin_dose: float,
    time_to_endoscophy_hours: float,
    rebleeding: str,
):
    """Score a single patient record with the loaded preprocessor and model.

    Each argument is one feature value exactly as the Gradio form supplies it.
    Categorical fields arrive as the dropdown strings, e.g. "yes"/"no" or
    "Banding".

    Returns:
        A ``(prediction, probability)`` tuple. ``prediction`` is the first
        element of ``model.predict``. ``probability`` is column 1 of
        ``model.predict_proba`` for the single input row. This is presumably
        the positive (1-year mortality) class; confirm against
        ``model.classes_``.
    """
    # Keys must match the column names the preprocessor was fitted on;
    # the DataFrame is passed to preprocessor.transform as-is.
    input_data = {
        "age": age,
        "sex": sex,
        "race": race,
        "etiology_cirrosis": etiology_cirrosis,
        "hepatorenal_syndrome": hepatorenal_syndrome,
        "omeprazole": omeprazole,
        "spironolactone": spironolactone,
        "furosemide": furosemide,
        "propanolol": propanolol,
        "dialisis": dialisis,
        "portal_vein_thrombosis": portal_vein_thrombosis,
        "ascitis": ascitis,
        "hepatocellular_carcinoma": hepatocellular_carcinoma,
        "albumin": albumin,
        "total_bilirrubin": total_bilirrubin,
        "direct_bilirrubina": direct_bilirrubina,
        "inr": inr,
        "creatinine": creatinine,
        "platelets": platelets,
        "ast": ast,
        "alt": alt,
        "hemoglobin": hemoglobin,
        "hematocrit": hematocrit,
        "leucocytes": leucocytes,
        "sodium": sodium,
        "potassium": potassium,
        "varices": varices,
        "red_wale_marks": red_wale_marks,
        "rupture_point": rupture_point,
        "active_bleeding": active_bleeding,
        "therapy": therapy,
        "terlipressin_dose": terlipressin_dose,
        # NOTE(review): this is the only key that uses hyphens. It is kept
        # as-is because it must match the training column name exactly.
        # Confirm against the fitted preprocessor before renaming it.
        "time-to-endoscophy_hours": time_to_endoscophy_hours,
        "rebleeding": rebleeding,
    }
    df = pd.DataFrame([input_data])
    processed_data = preprocessor.transform(df)
    prediction = model.predict(processed_data)[0]
    # Column 1 of predict_proba for the one (and only) row.
    probability = model.predict_proba(processed_data)[:, 1][0]
    return prediction, probability


# Markdown is kept at column 0 so the text is not affected by the indentation
# of the `with` blocks below.
_HEADER_MD = """
# EVB PROGNOSIS: 1-Year Mortality Risk Calculator

**Advanced Machine Learning Model for Predicting Post-Bleeding Survival in Cirrhotic Patients**
"""

_ABOUT_MD = """
### General Notes and Disclaimers

This tool is intended for research purposes only and therefore should not be used to provide medical advice, consultation, diagnosis, or treatment. Neither the authors nor the hospital guarantee the accuracy of its calculations for any particular patient. These services provide no warranties, express or implied, and shall not be liable for any direct, consequential, lost profits, or other damages incurred by the user of this information tool.

The model architecture consists of a Random Forest Classifier with isotonic calibration, carefully tuned to predict 1-year mortality after esophageal variceal bleeding. Let me break down the key components and their significance:

The Random Forest baseline configuration reflects a robust approach to handling complex medical data. The model uses 100 decision trees (n_estimators: 100) with bootstrapped sampling (bootstrap: True), meaning each tree is trained on a random subset of the data, which helps prevent overfitting. The trees use the Gini impurity criterion (criterion: 'gini') to evaluate the quality of splits, which is particularly effective for binary classification problems like mortality prediction.

The feature selection strategy employs the square root of the total number of features (max_features: 'sqrt'), a common practice that helps maintain model stability while capturing important variable interactions. The trees are allowed to grow to their full depth (max_depth: None), but are constrained by requiring at least 2 samples to create a split (min_samples_split: 2) and 1 sample in each leaf node (min_samples_leaf: 1). This configuration balances model complexity with predictive power.

The calibration layer is particularly important for medical applications where probability estimates need to be reliable. The model uses isotonic regression (method: 'isotonic') with 5-fold cross-validation (cv: 5). Isotonic calibration is especially suitable for this case because:

- It makes minimal assumptions about the form of the calibration curve
- It can handle the non-linear relationships often present in medical data
- It provides well-calibrated probabilities across the entire range of predictions

The ensemble approach (ensemble: True) in the calibration phase means that predictions are averaged across multiple calibrated models, each trained on different cross-validation folds. This helps reduce variance and provides more stable probability estimates.

For reproducibility, the model uses a fixed random state (random_state: 42). Cost-complexity pruning is not applied (ccp_alpha: 0.0), allowing the model to capture complex patterns in the data while relying on the ensemble nature of random forests to prevent overfitting.
"""

###############################
# GRADIO BLOCKS INTERFACE
###############################
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    # TAB 1: General Info
    with gr.Tab("General Info"):
        with gr.Row():
            age = gr.Slider(minimum=18, maximum=100, step=1, label="Age", value=50)
            sex = gr.Dropdown(choices=["male", "female"], label="Sex", value="male")
            race = gr.Dropdown(choices=["white", "black", "asian", "other"], label="Race", value="white")
            etiology_cirrosis = gr.Dropdown(
                choices=["alcohol", "hcv", "alcohol+hcv", "other"],
                label="Etiology Cirrhosis",
                value="alcohol",
            )
        with gr.Row():
            hepatorenal_syndrome = gr.Dropdown(choices=["yes", "no"], label="Hepatorenal Syndrome", value="no")
            omeprazole = gr.Dropdown(choices=["yes", "no"], label="Omeprazole", value="no")
            spironolactone = gr.Dropdown(choices=["yes", "no"], label="Spironolactone", value="yes")
            furosemide = gr.Dropdown(choices=["yes", "no"], label="Furosemide", value="yes")
            propanolol = gr.Dropdown(choices=["yes", "no"], label="Propanolol", value="no")
            dialisis = gr.Dropdown(choices=["yes", "no"], label="Dialysis", value="no")

    # TAB 2: Clinical Status
    with gr.Tab("Clinical Status"):
        with gr.Row():
            portal_vein_thrombosis = gr.Dropdown(choices=["yes", "no"], label="Portal Vein Thrombosis", value="no")
            ascitis = gr.Dropdown(choices=["yes", "no"], label="Ascites", value="yes")
            hepatocellular_carcinoma = gr.Dropdown(choices=["yes", "no"], label="Hepatocellular Carcinoma", value="no")
            varices = gr.Dropdown(choices=["yes", "no"], label="Varices", value="yes")
            red_wale_marks = gr.Dropdown(choices=["yes", "no"], label="Red Wale Marks", value="no")
            rupture_point = gr.Dropdown(choices=["yes", "no"], label="Rupture Point", value="no")
            active_bleeding = gr.Dropdown(choices=["yes", "no"], label="Active Bleeding", value="no")
            rebleeding = gr.Dropdown(choices=["yes", "no"], label="Rebleeding", value="no")
            therapy = gr.Dropdown(
                choices=["Banding", "Sclerotherapy", "No therapy"],
                label="Therapy",
                value="Banding",
            )
            terlipressin_dose = gr.Slider(minimum=0, maximum=20, step=1, label="Terlipressin Dose", value=2)
            time_to_endoscophy_hours = gr.Slider(
                minimum=0, maximum=48, step=1, label="Time to Endoscopy (Hours)", value=12
            )

    # TAB 3: Laboratory Values
    with gr.Tab("Laboratory Values"):
        with gr.Row():
            albumin = gr.Slider(minimum=1, maximum=5, step=0.1, label="Albumin", value=3.5)
            total_bilirrubin = gr.Slider(minimum=0.1, maximum=30, step=0.1, label="Total Bilirubin", value=2.0)
            direct_bilirrubina = gr.Slider(minimum=0.1, maximum=10, step=0.1, label="Direct Bilirubin", value=0.5)
            inr = gr.Slider(minimum=0.5, maximum=5, step=0.1, label="INR", value=1.2)
            creatinine = gr.Slider(minimum=0.1, maximum=10, step=0.1, label="Creatinine", value=1.0)
        with gr.Row():
            platelets = gr.Slider(minimum=10, maximum=500, step=1, label="Platelets", value=150)
            ast = gr.Slider(minimum=10, maximum=500, step=1, label="AST", value=35)
            alt = gr.Slider(minimum=10, maximum=500, step=1, label="ALT", value=25)
            hemoglobin = gr.Slider(minimum=5, maximum=20, step=0.1, label="Hemoglobin", value=13)
            hematocrit = gr.Slider(minimum=15, maximum=60, step=0.1, label="Hematocrit", value=40)
        with gr.Row():
            leucocytes = gr.Slider(minimum=1, maximum=50, step=0.1, label="Leukocytes (WBC)", value=6)
            sodium = gr.Slider(minimum=120, maximum=160, step=1, label="Sodium", value=140)
            potassium = gr.Slider(minimum=2, maximum=6, step=0.1, label="Potassium", value=4)

    # TAB 4: Messages & Info
    with gr.Tab("Messages & Info"):
        gr.Markdown(_ABOUT_MD)

    # Gradio passes inputs positionally, so this order MUST match the
    # parameter order of predict_patient_outcome.
    all_inputs = [
        age, sex, race, etiology_cirrosis, hepatorenal_syndrome, omeprazole,
        spironolactone, furosemide, propanolol, dialisis, portal_vein_thrombosis,
        ascitis, hepatocellular_carcinoma, albumin, total_bilirrubin,
        direct_bilirrubina, inr, creatinine, platelets, ast, alt, hemoglobin,
        hematocrit, leucocytes, sodium, potassium, varices, red_wale_marks,
        rupture_point, active_bleeding, therapy, terlipressin_dose,
        time_to_endoscophy_hours, rebleeding,
    ]

    # Prediction Button & Outputs
    with gr.Row():
        predict_btn = gr.Button("Predict Outcome", variant="primary")
    with gr.Row():
        pred_label = gr.Textbox(label="Prediction (Class)")
        pred_prob = gr.Textbox(label="Probability (Class=1)")

    predict_btn.click(
        fn=predict_patient_outcome,
        inputs=all_inputs,
        outputs=[pred_label, pred_prob],
    )

# For Gradio 4.31.4, we can't specify title/description in .launch()
if __name__ == "__main__":
    demo.launch()