# Hugging Face Space -- page status when this file was captured: "Runtime error".
import gradio as gr | |
import datasets as ds | |
import pandas as pd | |
import numpy as np | |
from sklearn.ensemble import RandomForestClassifier | |
from lime.lime_tabular import LimeTabularExplainer | |
# Download the training split of the wine-recognition dataset and turn it
# into a pandas DataFrame; surrounding whitespace is stripped from the
# column names so they can be addressed reliably (e.g. 'label').
wines = ds.load_dataset("katossky/wine-recognition", split='train')
wines = wines.to_pandas()
wines.columns = wines.columns.str.strip()

# Random forest fitted once at start-up on every column except 'label';
# it is the black-box model whose predictions are explained below.
predictor = RandomForestClassifier(
    n_estimators=1000, max_depth=5, n_jobs=4,
    random_state=44  # for reproducibility
)
predictor.fit(wines.drop('label', axis=1), wines['label'])
def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Predict the class of one wine and explain that prediction with LIME.

    The three ``instance_part_*`` arguments are single-row DataFrames that,
    concatenated column-wise, hold one value per feature in the same order
    as the training columns (every column of ``wines`` except ``'label'``).

    Args:
        instance_part_1: DataFrame with the first group of feature values.
        instance_part_2: DataFrame with the second group of feature values.
        instance_part_3: DataFrame with the third group of feature values.
        sigma: Kernel width handed to ``LimeTabularExplainer``.

    Returns:
        A ``(confidences, figure)`` tuple where ``confidences`` maps each
        class label (as a string) to its predicted probability and
        ``figure`` is the result of ``explanation.as_pyplot_figure`` for the
        most probable class.
    """
    # Use the exact columns the forest was fitted on, rather than assuming
    # that 'label' is the first column of the table.
    features = wines.drop('label', axis=1)
    feature_names = features.columns.to_list()
    class_names = [str(cls) for cls in predictor.classes_]

    instance_pd = pd.concat(
        [instance_part_1, instance_part_2, instance_part_3], axis=1
    )
    # Values typed into the UI may arrive as strings; force a float vector.
    instance_np = instance_pd.iloc[0].to_numpy(dtype=float)
    # Re-key by position so the predictor always sees its training names.
    instance_row = pd.DataFrame([instance_np], columns=feature_names)

    def predict_proba(samples):
        # LIME passes bare arrays; re-attach the fitted column names so
        # scikit-learn does not warn about missing feature names.
        return predictor.predict_proba(
            pd.DataFrame(samples, columns=feature_names)
        )

    explainer = LimeTabularExplainer(
        training_data=features.to_numpy(),  # LIME documents a 2-D array here
        feature_names=feature_names,
        class_names=class_names,
        discretize_continuous=False,
        kernel_width=sigma,
    )
    explanation = explainer.explain_instance(
        instance_np,
        predict_proba,
        top_labels=len(class_names),  # explain every class, not a fixed 3
        num_features=5,
    )

    predictions = predictor.predict_proba(instance_row)[0]
    # Index into predict_proba's output, which is how LIME keys its labels.
    label = int(np.argmax(predictions))
    # Key by the real class labels (as plain str/float) instead of the
    # positional index, so the displayed names match the dataset's classes.
    confidences = {
        name: float(probability)
        for name, probability in zip(class_names, predictions)
    }
    return (
        confidences,
        explanation.as_pyplot_figure(label=label)
    )
# Default kernel width: 0.75 * sqrt(number of feature columns), i.e. the
# column count of `wines` minus the 'label' column. The slider lets the
# user vary it between (almost) zero and twice that default.
sigma_default = 0.75*(wines.shape[1]-1)**0.5
sigma = gr.Slider(0.001, 2*sigma_default, value=sigma_default, label='σ')

# One random wine pre-fills the input tables.
instance_complete = wines.sample(1)

# Feature columns in training order; derived by dropping 'label' instead of
# relying on it being the first column.
_feature_names = wines.columns.drop('label').to_list()


def _feature_table(columns, **dataframe_kwargs):
    """Build a one-row, fixed-size numeric table pre-filled from the sample.

    The column count is taken from ``columns`` so it cannot drift out of
    sync with the headers.
    """
    return gr.Dataframe(
        headers=columns,
        row_count=(1, "fixed"),
        col_count=(len(columns), "fixed"),
        datatype="number",
        value=instance_complete[columns].values.tolist(),
        **dataframe_kwargs,
    )


# The features are split over three tables (5 + 4 + the remainder).
instance_part_1 = _feature_table(
    _feature_names[:5],
    label="Chemical properties of the wine",
)
instance_part_2 = _feature_table(
    _feature_names[5:9],
    label="",
    show_label=False,  # original author's note: does not work
)
instance_part_3 = _feature_table(
    _feature_names[9:],
    label="",
    show_label=False,  # original author's note: does not work
)

# Wire the inputs to plot_explanation; outputs are a label widget for the
# class confidences and a plot for the LIME figure.
demo = gr.Interface(
    fn=plot_explanation,
    inputs=[instance_part_1, instance_part_2, instance_part_3, sigma],
    outputs=["label", "plot"]
)
demo.launch()