File size: 2,394 Bytes
65e0688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1da6f58
65e0688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import datasets as ds
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer

# Fetch the training split of the wine dataset as a pandas DataFrame and
# strip stray whitespace from the column names.
wines = ds.load_dataset("katossky/wine-recognition", split='train').to_pandas()
wines.columns = wines.columns.str.strip()

# Random forest trained on every column except 'label', which is the target.
predictor = RandomForestClassifier(
    n_estimators=1000,
    max_depth=5,
    n_jobs=4,
    random_state=44,  # fixed seed, for reproducibility
)
predictor.fit(wines.drop(columns='label'), wines['label'])

def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Predict the class of one wine and explain that prediction with LIME.

    The three ``instance_part_*`` tables are placed side by side to rebuild a
    single feature row; their column names must be the feature names the
    model was trained on, in the same order.

    Args:
        instance_part_1: one-row DataFrame with the first group of features.
        instance_part_2: one-row DataFrame with the second group of features.
        instance_part_3: one-row DataFrame with the remaining features.
        sigma: kernel width passed to ``LimeTabularExplainer``.

    Returns:
        A ``(confidences, figure)`` tuple: ``confidences`` maps each class
        label (as a string) to its predicted probability, and ``figure`` is
        the LIME pyplot figure explaining the most probable class.
    """
    features = wines.drop('label', axis=1)
    feature_names = features.columns.to_list()
    # Real class labels of the fitted model, in predict_proba column order.
    class_names = [str(cls) for cls in predictor.classes_]

    # Table cells may arrive as ints or numeric strings; normalise to floats.
    instance_pd = pd.concat(
        [instance_part_1, instance_part_2, instance_part_3], axis=1
    ).astype(float)
    instance_np = instance_pd.to_numpy()[0]  # first (only) row as a 1-D vector

    def predict_proba(samples):
        # LIME perturbs plain arrays; put the column names back so the model
        # sees the same named features it was fitted on.
        return predictor.predict_proba(
            pd.DataFrame(samples, columns=feature_names)
        )

    explainer = LimeTabularExplainer(
        training_data=features.to_numpy(),  # LIME documents a numpy 2-D array
        feature_names=feature_names,
        class_names=class_names,
        discretize_continuous=False,
        kernel_width=sigma,
        random_state=44,  # fixed seed: same inputs give the same explanation
    )
    explanation = explainer.explain_instance(
        instance_np,
        predict_proba,
        top_labels=len(class_names),  # explain every class the model knows
        num_features=5,
    )

    predictions = predictor.predict_proba(instance_pd)[0]
    # LIME identifies classes by their position in predict_proba's output.
    label = int(np.argmax(predictions))
    confidences = {
        name: float(proba) for name, proba in zip(class_names, predictions)
    }
    return confidences, explanation.as_pyplot_figure(label=label)

# Default kernel width: 0.75 * sqrt(number of non-label columns). The slider
# lets the user vary it between 0.001 and twice that default.
sigma_default = 0.75 * (wines.shape[1] - 1) ** 0.5
sigma = gr.Slider(0.001, 2 * sigma_default, value=sigma_default, label='σ')

# One randomly drawn wine pre-fills the input tables.
instance_complete = wines.sample(1)


def _single_row_table(cols, n_cols, **label_options):
    """Build a fixed 1 x n_cols numeric table for the `cols` slice of `wines`."""
    return gr.Dataframe(
        headers=wines.columns[cols].to_list(),
        row_count=(1, "fixed"),
        col_count=(n_cols, "fixed"),
        datatype="number",
        value=instance_complete.iloc[:, cols].values.tolist(),
        **label_options,
    )


# The feature columns are split over three tables; column 0 is skipped.
instance_part_1 = _single_row_table(
    slice(1, 6), 5, label="Chemical properties of the wine"
)
# Author's note on the next two tables: show_label=False "does not work".
instance_part_2 = _single_row_table(slice(6, 10), 4, label="", show_label=False)
instance_part_3 = _single_row_table(slice(10, None), 4, label="", show_label=False)

# Wire the inputs to plot_explanation and start the app.
demo = gr.Interface(
    fn=plot_explanation,
    inputs=[instance_part_1, instance_part_2, instance_part_3, sigma],
    outputs=["label", "plot"],
)
demo.launch()