"""Gradio demo: class probabilities of a soft-voting ensemble.

Three user-selected scikit-learn classifiers, plus a soft-voting
``VotingClassifier`` built from them with user-selected weights, are fit on a
tiny fixed 2-D toy dataset. A grouped bar chart shows the probability each
model assigns to each class for the first training sample.
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

# Estimator factories keyed by the label shown in the UI. Each call builds a
# fresh, unfitted estimator so selections never share fitted state.
MODEL_FACTORIES = {
    "Logistic Regression": lambda: LogisticRegression(max_iter=1000, random_state=123),
    "Random Forest": lambda: RandomForestClassifier(n_estimators=100, random_state=123),
    "Gaussian Naive Bayes": GaussianNB,
}
# Dropdown choices, in the same order as MODEL_FACTORIES.
MODEL_CHOICES = list(MODEL_FACTORIES)

# Toy training data: four 2-D samples, two per class (labels 1 and 2).
X_TRAIN = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
Y_TRAIN = np.array([1, 1, 2, 2])


def choose_model(model):
    """Return a new, unfitted estimator for the UI label *model*.

    Raises:
        ValueError: if *model* is not a key of ``MODEL_FACTORIES``.
    """
    try:
        factory = MODEL_FACTORIES[model]
    except KeyError:
        raise ValueError("Model is not supported.") from None
    return factory()


def get_proba_plots(
    model_1, model_2, model_3, model_1_weight, model_2_weight, model_3_weight
):
    """Fit the chosen models and their soft-voting ensemble, then plot them.

    Args:
        model_1, model_2, model_3: UI labels of the three base classifiers
            (keys of ``MODEL_FACTORIES``).
        model_1_weight, model_2_weight, model_3_weight: weights applied to the
            corresponding base classifier's probabilities in the soft vote.

    Returns:
        A matplotlib ``Figure`` with grouped bars of the class-1 and class-2
        probabilities predicted for the first training sample by each base
        classifier and by the ``VotingClassifier``.

    Raises:
        ValueError: if any model label is not supported.
    """
    clf1 = choose_model(model_1)
    clf2 = choose_model(model_2)
    clf3 = choose_model(model_3)
    eclf = VotingClassifier(
        estimators=[("clf1", clf1), ("clf2", clf2), ("clf3", clf3)],
        voting="soft",
        weights=[model_1_weight, model_2_weight, model_3_weight],
    )

    # Class probabilities from every classifier; the ensemble comes last.
    probas = [
        c.fit(X_TRAIN, Y_TRAIN).predict_proba(X_TRAIN)
        for c in (clf1, clf2, clf3, eclf)
    ]

    # Probabilities of class 1 / class 2 for the first sample in the dataset.
    class1_1 = [pr[0, 0] for pr in probas]
    class2_1 = [pr[0, 1] for pr in probas]

    ind = np.arange(len(probas))  # one bar group per classifier
    width = 0.35  # bar width

    fig, ax = plt.subplots()

    # Base classifiers in green; their series is zero in the ensemble's slot.
    p1 = ax.bar(ind, class1_1[:-1] + [0], width, color="green", edgecolor="k")
    p2 = ax.bar(
        ind + width, class2_1[:-1] + [0], width, color="lightgreen", edgecolor="k"
    )
    # VotingClassifier in blue, drawn only in the last slot.
    ax.bar(ind, [0, 0, 0, class1_1[-1]], width, color="blue", edgecolor="k")
    ax.bar(
        ind + width, [0, 0, 0, class2_1[-1]], width, color="steelblue", edgecolor="k"
    )

    # Dashed separator between the base classifiers and the ensemble.
    ax.axvline(2.8, color="k", linestyle="dashed")
    ax.set_xticks(ind + width)
    # Labels follow bar order: clf1, clf2, clf3, ensemble.
    ax.set_xticklabels(
        [
            f"{model_1}\nweight {model_1_weight}",
            f"{model_2}\nweight {model_2_weight}",
            f"{model_3}\nweight {model_3_weight}",
            "VotingClassifier\n(average probabilities)",
        ],
        rotation=40,
        ha="right",
    )
    ax.set_ylim([0, 1])
    ax.set_title("Class probabilities for sample 1 by different classifiers")
    ax.legend([p1[0], p2[0]], ["class 1", "class 2"], loc="upper left")
    fig.tight_layout()

    # Unregister the figure from pyplot so repeated callbacks do not accumulate
    # open figures. The returned Figure object can still be saved and rendered.
    plt.close(fig)
    return fig


with gr.Blocks() as demo:
    with gr.Row():
        model_1 = gr.Dropdown(
            MODEL_CHOICES, label="Model 1", value="Logistic Regression"
        )
        model_2 = gr.Dropdown(MODEL_CHOICES, label="Model 2", value="Random Forest")
        model_3 = gr.Dropdown(
            MODEL_CHOICES, label="Model 3", value="Gaussian Naive Bayes"
        )
    with gr.Row():
        model_1_weight = gr.Slider(
            minimum=1, maximum=10, value=1, label="Model 1 Weight", step=1
        )
        model_2_weight = gr.Slider(
            minimum=1, maximum=10, value=1, label="Model 2 Weight", step=1
        )
        model_3_weight = gr.Slider(
            minimum=1, maximum=10, value=5, label="Model 3 Weight", step=1
        )
    proba_plots = gr.Plot()

    # Positional argument order must match get_proba_plots' signature.
    plot_inputs = [
        model_1,
        model_2,
        model_3,
        model_1_weight,
        model_2_weight,
        model_3_weight,
    ]
    # Redraw whenever any control changes, and once when the page loads.
    for component in plot_inputs:
        component.change(get_proba_plots, plot_inputs, proba_plots, queue=False)
    demo.load(get_proba_plots, plot_inputs, proba_plots, queue=False)

if __name__ == "__main__":
    demo.launch()