hysts HF staff commited on
Commit
bbe49e5
·
1 Parent(s): d4cb7c9
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +32 -13
  3. requirements.txt +3 -2
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🏃
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -47,8 +47,9 @@ model = load_model()
47
  labels = load_labels()
48
 
49
 
50
- def predict(image: PIL.Image.Image,
51
- score_threshold: float) -> dict[str, float]:
 
52
  _, height, width, _ = model.input_shape
53
  image = np.asarray(image)
54
  image = tf.image.resize(image,
@@ -60,12 +61,19 @@ def predict(image: PIL.Image.Image,
60
  image = image / 255.
61
  probs = model.predict(image[None, ...])[0]
62
  probs = probs.astype(float)
63
- res = dict()
64
- for prob, label in zip(probs.tolist(), labels):
 
 
 
 
 
 
65
  if prob < score_threshold:
66
- continue
67
- res[label] = prob
68
- return res
 
69
 
70
 
71
  image_paths = load_sample_image_paths()
@@ -83,15 +91,26 @@ with gr.Blocks(css='style.css') as demo:
83
  value=0.5)
84
  run_button = gr.Button('Run')
85
  with gr.Column():
86
- result = gr.Label(label='Output')
 
 
 
 
 
 
 
 
 
87
  gr.Examples(examples=examples,
88
  inputs=[image, score_threshold],
89
- outputs=result,
90
  fn=predict,
91
  cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
92
 
93
- run_button.click(fn=predict,
94
- inputs=[image, score_threshold],
95
- outputs=result,
96
- api_name='predict')
 
 
97
  demo.queue().launch()
 
47
  labels = load_labels()
48
 
49
 
50
+ def predict(
51
+ image: PIL.Image.Image, score_threshold: float
52
+ ) -> tuple[dict[str, float], dict[str, float], str]:
53
  _, height, width, _ = model.input_shape
54
  image = np.asarray(image)
55
  image = tf.image.resize(image,
 
61
  image = image / 255.
62
  probs = model.predict(image[None, ...])[0]
63
  probs = probs.astype(float)
64
+
65
+ indices = np.argsort(probs)[::-1]
66
+ result_all = dict()
67
+ result_threshold = dict()
68
+ for index in indices:
69
+ label = labels[index]
70
+ prob = probs[index]
71
+ result_all[label] = prob
72
  if prob < score_threshold:
73
+ break
74
+ result_threshold[label] = prob
75
+ result_text = ', '.join(result_all.keys())
76
+ return result_threshold, result_all, result_text
77
 
78
 
79
  image_paths = load_sample_image_paths()
 
91
  value=0.5)
92
  run_button = gr.Button('Run')
93
  with gr.Column():
94
+ with gr.Tabs():
95
+ with gr.Tab(label='Output'):
96
+ result = gr.Label(label='Output', show_label=False)
97
+ with gr.Tab(label='JSON'):
98
+ result_json = gr.JSON(label='JSON output',
99
+ show_label=False)
100
+ with gr.Tab(label='Text'):
101
+ result_text = gr.Text(label='Text output',
102
+ show_label=False,
103
+ lines=5)
104
  gr.Examples(examples=examples,
105
  inputs=[image, score_threshold],
106
+ outputs=[result, result_json, result_text],
107
  fn=predict,
108
  cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
109
 
110
+ run_button.click(
111
+ fn=predict,
112
+ inputs=[image, score_threshold],
113
+ outputs=[result, result_json, result_text],
114
+ api_name='predict',
115
+ )
116
  demo.queue().launch()
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- pillow>=9.0.0
2
- tensorflow>=2.7.0
3
  git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru
 
 
 
 
 
 
1
  git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru
2
+ pillow==10.0.0
3
+ pydantic==1.10.11
4
+ tensorflow==2.13.0