dawn17 commited on
Commit
bc264b9
·
1 Parent(s): 6ef1f60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -90,10 +90,10 @@ class Gradio:
90
  def inference(
91
  self,
92
  input_img: np.array,
93
- transparency: float,
94
- ntop_classes: int,
95
- layer_nums: List,
96
- cam_for_class: str,
97
  ):
98
  self.model.eval()
99
  input_img = transform(input_img)
@@ -110,7 +110,7 @@ class Gradio:
110
 
111
  class_id = (
112
  prediction[0][0]
113
- if cam_for_class in ["default", "", None]
114
  else classes.index(cam_for_class)
115
  )
116
  visualization = self.cam(
@@ -164,7 +164,7 @@ demo = gr.Interface(
164
  gr.Image(shape=(32, 32)).style(width=128, height=128),
165
  gr.Label(label="Top Classes"),
166
  ],
167
- examples=[[os.path.join("samples/", f)] for f in os.listdir("samples/")]
168
  )
169
 
170
 
 
90
  def inference(
91
  self,
92
  input_img: np.array,
93
+ transparency: float=0.7,
94
+ ntop_classes: int=2,
95
+ layer_nums: List=[3, 4],
96
+ cam_for_class: str="default",
97
  ):
98
  self.model.eval()
99
  input_img = transform(input_img)
 
110
 
111
  class_id = (
112
  prediction[0][0]
113
+ if cam_for_class in ["default", ""]
114
  else classes.index(cam_for_class)
115
  )
116
  visualization = self.cam(
 
164
  gr.Image(shape=(32, 32)).style(width=128, height=128),
165
  gr.Label(label="Top Classes"),
166
  ],
167
+ examples=[[os.path.join("./samples/", f)] for f in os.listdir("samples/")]
168
  )
169
 
170