cmpatino's picture
Initial commit
307e9da
raw
history blame
5.2 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
def choose_model(model):
    """Build a fresh, unfitted classifier for a dropdown label.

    Args:
        model: Display name chosen in the UI. Supported values are
            "Logistic Regression", "Random Forest" and
            "Gaussian Naive Bayes".

    Returns:
        A new scikit-learn estimator. The logistic regression and random
        forest are seeded with ``random_state=123``.

    Raises:
        ValueError: If ``model`` matches none of the supported labels.
    """
    # Each entry pairs a label with a zero-argument builder; the builders are
    # lambdas so that only the matching estimator is ever constructed.
    builders = (
        (
            "Logistic Regression",
            lambda: LogisticRegression(max_iter=1000, random_state=123),
        ),
        (
            "Random Forest",
            lambda: RandomForestClassifier(n_estimators=100, random_state=123),
        ),
        (
            "Gaussian Naive Bayes",
            lambda: GaussianNB(),
        ),
    )
    for label, build in builders:
        if model == label:
            return build()
    raise ValueError("Model is not supported.")
def get_proba_plots(
    model_1, model_2, model_3, model_1_weight, model_2_weight, model_3_weight
):
    """Plot per-model and soft-voting class probabilities for one sample.

    Three classifiers (chosen via ``choose_model``) and a soft-voting
    ``VotingClassifier`` that combines them are each fitted on a fixed
    four-point toy dataset. The predicted probabilities for the first
    sample are then drawn as grouped bars: one group per individual
    model and a final group for the ensemble.

    Args:
        model_1, model_2, model_3: Dropdown labels passed to ``choose_model``.
        model_1_weight, model_2_weight, model_3_weight: Soft-voting weights
            for the corresponding model inside the ``VotingClassifier``.

    Returns:
        A ``matplotlib.figure.Figure`` containing the bar chart.

    Raises:
        ValueError: Propagated from ``choose_model`` for unknown labels.
    """
    clf1 = choose_model(model_1)
    clf2 = choose_model(model_2)
    clf3 = choose_model(model_3)
    X = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
    y = np.array([1, 1, 2, 2])
    eclf = VotingClassifier(
        estimators=[("clf1", clf1), ("clf2", clf2), ("clf3", clf3)],
        voting="soft",
        weights=[model_1_weight, model_2_weight, model_3_weight],
    )

    # Fit every model and predict class probabilities for all samples.
    probas = [c.fit(X, y).predict_proba(X) for c in (clf1, clf2, clf3, eclf)]
    # Probabilities for the first sample; predict_proba columns follow the
    # sorted labels, so column 0 is class 1 and column 1 is class 2.
    class1_1 = [pr[0, 0] for pr in probas]
    class2_1 = [pr[0, 1] for pr in probas]

    ind = np.arange(len(probas))  # group positions (3 models + ensemble)
    width = 0.35  # bar width

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # in a global registry (leaking one figure per Gradio event unless
    # closed) and its "current axes" state is shared between concurrent
    # requests. Gradio's gr.Plot only needs the Figure object to savefig().
    fig = Figure()
    ax = fig.subplots()

    # Bars for the three individual classifiers (ensemble slot left at 0).
    p1 = ax.bar(
        ind, class1_1[:-1] + [0], width, color="green", edgecolor="k"
    )
    p2 = ax.bar(
        ind + width,
        class2_1[:-1] + [0],
        width,
        color="lightgreen",
        edgecolor="k",
    )
    # Bars for the VotingClassifier (only the last group is non-zero).
    ax.bar(ind, [0, 0, 0, class1_1[-1]], width, color="blue", edgecolor="k")
    ax.bar(
        ind + width, [0, 0, 0, class2_1[-1]], width, color="steelblue", edgecolor="k"
    )

    # Dashed separator between the individual models and the ensemble.
    ax.axvline(2.8, color="k", linestyle="dashed")
    ax.set_xticks(ind + width)
    # Labels must follow the bar order clf1, clf2, clf3, ensemble.
    ax.set_xticklabels(
        [
            f"{model_1}\nweight {model_1_weight}",
            f"{model_2}\nweight {model_2_weight}",
            f"{model_3}\nweight {model_3_weight}",
            "VotingClassifier\n(average probabilities)",
        ],
        rotation=40,
        ha="right",
    )
    ax.set_ylim([0, 1])
    ax.set_title("Class probabilities for sample 1 by different classifiers")
    ax.legend([p1[0], p2[0]], ["class 1", "class 2"], loc="upper left")
    fig.tight_layout()
    return fig
# Labels offered by every model dropdown; they must match the names that
# choose_model accepts.
MODEL_CHOICES = [
    "Logistic Regression",
    "Random Forest",
    "Gaussian Naive Bayes",
]

with gr.Blocks() as demo:
    with gr.Row():
        # Each dropdown gets its own copy of the choices list.
        model_1 = gr.Dropdown(
            list(MODEL_CHOICES),
            label="Model 1",
            value="Logistic Regression",
        )
        model_2 = gr.Dropdown(
            list(MODEL_CHOICES),
            label="Model 2",
            value="Random Forest",
        )
        model_3 = gr.Dropdown(
            list(MODEL_CHOICES),
            label="Model 3",
            value="Gaussian Naive Bayes",
        )
    with gr.Row():
        model_1_weight = gr.Slider(
            minimum=1, maximum=10, value=1, label="Model 1 Weight", step=1
        )
        model_2_weight = gr.Slider(
            minimum=1, maximum=10, value=1, label="Model 2 Weight", step=1
        )
        model_3_weight = gr.Slider(
            minimum=1, maximum=10, value=5, label="Model 3 Weight", step=1
        )
    proba_plots = gr.Plot()

    # Input components in the same order as get_proba_plots' parameters.
    inputs = [
        model_1,
        model_2,
        model_3,
        model_1_weight,
        model_2_weight,
        model_3_weight,
    ]
    # Redraw the plot whenever any control changes...
    for component in inputs:
        component.change(get_proba_plots, inputs, proba_plots, queue=False)
    # ...and once when the page first loads.
    demo.load(get_proba_plots, inputs, proba_plots, queue=False)

if __name__ == "__main__":
    demo.launch()