File size: 3,587 Bytes
b1b3f23
d8edf39
 
fa636b5
 
8d20412
e3914b4
 
 
 
8d20412
8e67e6e
b590d13
e3914b4
 
 
 
 
 
 
 
 
 
aa3b50d
e3914b4
aa3b50d
a9b677d
8d20412
a9b677d
c38a7bb
e3914b4
 
73c8e91
 
e572b7a
73c8e91
6b38cd2
e3914b4
 
 
 
 
 
4365fdc
 
a9b677d
 
 
 
e3914b4
 
 
 
 
 
 
 
 
 
 
 
a9b677d
 
e3914b4
 
 
 
 
 
 
 
 
4790b80
decc6b2
e3914b4
 
 
 
73c8e91
fa636b5
e3914b4
a9b677d
 
e3914b4
429b311
 
 
 
6b38cd2
429b311
6b38cd2
971bf27
 
73c8e91
 
 
 
 
971bf27
4365fdc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from cgitb import enable
from pyexpat import model
from statistics import mode
import numpy as np
import gradio as gr

import argparse
import os
from os.path import exists, dirname
import sys
import json
import flask
from PIL import Image

# Make the parent of the current working directory importable so that the
# project-local ``bayes`` package imported below can be resolved.
# NOTE(review): this is relative to the directory the script is launched
# from, not to this file's location.
parent_dir = dirname(
    os.path.abspath(os.getcwd())
)
sys.path.extend([parent_dir])

from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif	


def get_image_data(inp_image):
    """Load ``inp_image`` via the dataset helper and build its model bundle.

    Args:
        inp_image: value forwarded unchanged to ``get_dataset_by_name``
            (called with ``get_label=False``).

    Returns:
        A ``(image, model_and_data)`` tuple: the object returned by
        ``get_dataset_by_name`` and the result of passing that object to
        ``process_imagenet_get_model``.
    """
    loaded = get_dataset_by_name(inp_image, get_label=False)
    # Tuple elements are evaluated left to right, so the model is built after
    # the image has been loaded, exactly as before.
    return loaded, process_imagenet_get_model(loaded)


def segmentation_generation(input_image, c_width, n_top, n_gif_imgs):
    """Explain the model's prediction for ``input_image`` and render it as a GIF.

    A ``BayesLocalExplanations`` explainer (LIME kernel, image mode) is fitted
    over the image's segments. The resulting posterior (``rout['blr']``) is
    handed to ``create_gif``.

    Args:
        input_image: Image supplied by the Gradio input. It is forwarded
            unchanged to ``get_image_data`` and ``create_gif``.
        c_width: Credible-interval width, passed to ``explain`` as
            ``cred_width``.
        n_top: Forwarded to ``create_gif``. The UI labels it "n_top_segs",
            presumably the number of top segments highlighted.
        n_gif_imgs: Forwarded to ``create_gif``. The UI labels it
            "n_gif_images", presumably the number of GIF frames.

    Returns:
        Whatever ``create_gif`` returns. The Gradio interface in this file
        expects an (HTML, prediction text) pair.
    """
    print("Inputs Received:", input_image, c_width, n_top, n_gif_imgs)

    # Only the model bundle is needed here; the loaded image itself is unused.
    _, model_and_data = get_image_data(input_image)

    # Explain the first test instance using its precomputed segmentation.
    instance = model_and_data["xtest"][0]
    segments = model_and_data["xtest_segs"][0]
    get_model = model_and_data["model"]

    # Wrapped model built for this instance/segmentation.
    # NOTE(review): presumably it maps segment on/off vectors to class scores.
    # Confirm in bayes.models.
    cur_model = get_model(instance, segments)

    # Background data for the explainer (helper from bayes.models).
    xtrain = get_xtrain(segments)

    # The class the model predicts on the first background row is the label
    # that gets explained.
    prediction = np.argmax(cur_model(xtrain[:1]), axis=1)

    # Every segment is treated as a categorical feature.
    exp_init = BayesLocalExplanations(training_data=xtrain,
                                      data="image",
                                      kernel="lime",
                                      categorical_features=np.arange(xtrain.shape[1]),
                                      verbose=True)
    # Explain the all-ones vector, i.e. every segment switched on.
    rout = exp_init.explain(classifier_f=cur_model,
                            data=np.ones_like(xtrain[0]),
                            label=int(prediction[0]),
                            cred_width=c_width,
                            focus_sample=False,
                            l2=False)

    # Create the gif of the explanation
    return create_gif(rout['blr'], input_image, segments, instance, prediction[0], n_gif_imgs, n_top)

if __name__ == "__main__":
    # Assemble and serve the Gradio demo around segmentation_generation.
    # NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are the
    # legacy Gradio API (deprecated in 3.x, removed in 4.x). Confirm that the
    # pinned gradio version still supports them.
    image_input = gr.inputs.Image(label="Input Image (Or select an example)", type="pil")
    outputs = [
        gr.outputs.HTML(label="Output GIF"),
        gr.outputs.Textbox(label="Prediction"),
    ]

    # After the image, the sliders must follow segmentation_generation's
    # parameter order: (c_width, n_top, n_gif_imgs).
    inputs = [
        image_input,
        gr.inputs.Slider(minimum=0.01, maximum=0.8, step=0.01, default=0.01,
                         label="cred_width", optional=False),
        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=5,
                         label="n_top_segs", optional=False),
        gr.inputs.Slider(minimum=10, maximum=100, step=1, default=30,
                         label="n_gif_images", optional=False),
    ]

    # Each example row is [image path, cred_width, n_top_segs, n_gif_images].
    # Only the path and n_top_segs vary between examples.
    example_top_segs = [
        ("./data/diego.png", 7),
        ("./data/french_bulldog.jpg", 5),
        ("./data/pepper.jpeg", 5),
        ("./data/bird.jpg", 5),
        ("./data/hockey.jpg", 5),
    ]
    examples = [[path, 0.01, top_segs, 50] for path, top_segs in example_top_segs]

    demo = gr.Interface(segmentation_generation, inputs, outputs=outputs, examples=examples)
    demo.launch(enable_queue=True)