dawn17 commited on
Commit
449d5f0
·
1 Parent(s): cd6f1c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -31,16 +31,16 @@ transform = transforms.Compose([
31
  inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
32
 
33
  classes = [
34
- "plane",
35
- "car",
36
- "bird",
37
- "cat",
38
- "deer",
39
- "dog",
40
- "frog",
41
- "horse",
42
- "ship",
43
- "truck",
44
  ]
45
 
46
 
@@ -128,8 +128,8 @@ class Gradio:
128
 
129
  method = Gradio(model_path="./checkpoint/model.pt")
130
  demo = gr.Interface(
131
- method.inference,
132
- [
133
  gr.Image(shape=(32, 32), label="Input Image", value="./samples/bird_and_plane_.jpeg"),
134
  gr.Slider(
135
  minimum=0,
@@ -160,10 +160,11 @@ demo = gr.Interface(
160
  info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
161
  ),
162
  ],
163
- [
164
  gr.Image(shape=(32, 32)).style(width=128, height=128),
165
  gr.Label(label="Top Classes"),
166
  ],
 
167
  )
168
 
169
 
 
31
  inv_normalize = transforms.Normalize(mean=inv_mean, std=inv_std)
32
 
33
  classes = [
34
+ 'airplane',
35
+ 'automobile',
36
+ 'bird',
37
+ 'cat',
38
+ 'deer',
39
+ 'dog',
40
+ 'frog',
41
+ 'horse',
42
+ 'ship',
43
+ 'truck'
44
  ]
45
 
46
 
 
128
 
129
  method = Gradio(model_path="./checkpoint/model.pt")
130
  demo = gr.Interface(
131
+ fn=method.inference,
132
+ inputs=[
133
  gr.Image(shape=(32, 32), label="Input Image", value="./samples/bird_and_plane_.jpeg"),
134
  gr.Slider(
135
  minimum=0,
 
160
  info="This section showcases the specific region of interest within the input image that the Class Activation Map (CAM) algorithm emphasizes to make predictions based on the selected class from the dropdown menu. The 'default' value serves as the default choice, representing the top class predicted by the model.",
161
  ),
162
  ],
163
+ outputs=[
164
  gr.Image(shape=(32, 32)).style(width=128, height=128),
165
  gr.Label(label="Top Classes"),
166
  ],
167
+ examples=[[os.path.join("./samples/", f)] for f in os.listdir("./samples/")]
168
  )
169
 
170