Update app.py
Browse files
app.py
CHANGED
@@ -31,16 +31,16 @@ transform = transforms.Compose([
|
|
31 |
inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
|
32 |
|
33 |
classes = [
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
]
|
45 |
|
46 |
|
@@ -128,8 +128,8 @@ class Gradio:
|
|
128 |
|
129 |
method = Gradio(model_path="./checkpoint/model.pt")
|
130 |
demo = gr.Interface(
|
131 |
-
method.inference,
|
132 |
-
[
|
133 |
gr.Image(shape=(32, 32), label="Input Image", value="./samples/bird_and_plane_.jpeg"),
|
134 |
gr.Slider(
|
135 |
minimum=0,
|
@@ -160,10 +160,11 @@ demo = gr.Interface(
|
|
160 |
info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
|
161 |
),
|
162 |
],
|
163 |
-
[
|
164 |
gr.Image(shape=(32, 32)).style(width=128, height=128),
|
165 |
gr.Label(label="Top Classes"),
|
166 |
],
|
|
|
167 |
)
|
168 |
|
169 |
|
|
|
31 |
inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
|
32 |
|
33 |
classes = [
|
34 |
+
'airplane',
|
35 |
+
'automobile',
|
36 |
+
'bird',
|
37 |
+
'cat',
|
38 |
+
'deer',
|
39 |
+
'dog',
|
40 |
+
'frog',
|
41 |
+
'horse',
|
42 |
+
'ship',
|
43 |
+
'truck'
|
44 |
]
|
45 |
|
46 |
|
|
|
128 |
|
129 |
method = Gradio(model_path="./checkpoint/model.pt")
|
130 |
demo = gr.Interface(
|
131 |
+
fn=method.inference,
|
132 |
+
inputs=[
|
133 |
gr.Image(shape=(32, 32), label="Input Image", value="./samples/bird_and_plane_.jpeg"),
|
134 |
gr.Slider(
|
135 |
minimum=0,
|
|
|
160 |
info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
|
161 |
),
|
162 |
],
|
163 |
+
outputs=[
|
164 |
gr.Image(shape=(32, 32)).style(width=128, height=128),
|
165 |
gr.Label(label="Top Classes"),
|
166 |
],
|
167 |
+
examples=[[os.path.join("./samples/", f)] for f in os.listdir("./samples/")]
|
168 |
)
|
169 |
|
170 |
|