narugo1992 commited on
Commit
6823fd7
1 Parent(s): 9469fcd

dev(narugo): upload inception v3

Browse files
Files changed (1) hide show
  1. app.py +15 -6
app.py CHANGED
@@ -8,12 +8,20 @@ from huggingface_hub import hf_hub_download
8
  from imgutils.data import load_image
9
  from imgutils.utils import open_onnx_model
10
 
 
 
 
 
 
 
 
 
11
 
12
  @lru_cache()
13
- def _onnx_model():
14
  return open_onnx_model(hf_hub_download(
15
  'deepghs/imgutils-models',
16
- 'nsfw/nsfwjs.onnx'
17
  ))
18
 
19
 
@@ -25,9 +33,9 @@ def _image_preprocess(image, size: int = 224) -> np.ndarray:
25
  _LABELS = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
26
 
27
 
28
- def predict(image):
29
- input_ = _image_preprocess(image).astype(np.float32)
30
- output_, = _onnx_model().run(['dense_3'], {'input_1': input_})
31
  return dict(zip(_LABELS, map(float, output_[0])))
32
 
33
 
@@ -36,6 +44,7 @@ if __name__ == '__main__':
36
  with gr.Row():
37
  with gr.Column():
38
  gr_input_image = gr.Image(type='pil', label='Original Image')
 
39
  gr_btn_submit = gr.Button(value='Tagging', variant='primary')
40
 
41
  with gr.Column():
@@ -43,7 +52,7 @@ if __name__ == '__main__':
43
 
44
  gr_btn_submit.click(
45
  predict,
46
- inputs=[gr_input_image],
47
  outputs=[gr_ratings],
48
  )
49
  demo.queue(os.cpu_count()).launch()
 
8
  from imgutils.data import load_image
9
  from imgutils.utils import open_onnx_model
10
 
11
+ _MODELS = [
12
+ ('nsfwjs.onnx', 224),
13
+ ('inception_v3.onnx', 299),
14
+ ]
15
+ _MODEL_NAMES = [name for name, _ in _MODELS]
16
+ _DEFAULT_MODEL_NAME = _MODEL_NAMES[0]
17
+ _MODEL_TO_SIZE = dict(_MODELS)
18
+
19
 
20
  @lru_cache()
21
+ def _onnx_model(name):
22
  return open_onnx_model(hf_hub_download(
23
  'deepghs/imgutils-models',
24
+ f'nsfw/{name}'
25
  ))
26
 
27
 
 
33
  _LABELS = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
34
 
35
 
36
+ def predict(image, model_name):
37
+ input_ = _image_preprocess(image, _MODEL_TO_SIZE[model_name]).astype(np.float32)
38
+ output_, = _onnx_model(model_name).run(['dense_3'], {'input_1': input_})
39
  return dict(zip(_LABELS, map(float, output_[0])))
40
 
41
 
 
44
  with gr.Row():
45
  with gr.Column():
46
  gr_input_image = gr.Image(type='pil', label='Original Image')
47
+ gr_model = gr.Dropdown(_MODEL_NAMES, value=_DEFAULT_MODEL_NAME, label='Model')
48
  gr_btn_submit = gr.Button(value='Tagging', variant='primary')
49
 
50
  with gr.Column():
 
52
 
53
  gr_btn_submit.click(
54
  predict,
55
+ inputs=[gr_input_image, gr_model],
56
  outputs=[gr_ratings],
57
  )
58
  demo.queue(os.cpu_count()).launch()