lilferrit commited on
Commit
c473c85
1 Parent(s): 61f6c42

implemented counting ui

Browse files
eggcount/gradient.py CHANGED
@@ -140,7 +140,8 @@ def contour_thresh(
140
  img: np.ndarray,
141
  color_thresh: int = 75,
142
  avg_area: float = 800,
143
- kernal_size: tuple[int, int] = (3, 3)
 
144
  ) -> Dict:
145
  visualization_img = img.copy()
146
  img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
@@ -158,6 +159,11 @@ def contour_thresh(
158
  area = cv2.contourArea(cnt)
159
 
160
  if area > avg_area / 2:
 
 
 
 
 
161
  cv2.drawContours(visualization_img, [cnt], -1, (255, 0, 0), 2)
162
  curr_num = round(area / avg_area)
163
  num += curr_num
 
140
  img: np.ndarray,
141
  color_thresh: int = 75,
142
  avg_area: float = 800,
143
+ kernal_size: tuple[int, int] = (3, 3),
144
+ max_eggs: Optional[int] = None
145
  ) -> Dict:
146
  visualization_img = img.copy()
147
  img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
 
159
  area = cv2.contourArea(cnt)
160
 
161
  if area > avg_area / 2:
162
+ curr_num = round(area / avg_area)
163
+
164
+ if max_eggs and (curr_num > max_eggs):
165
+ continue
166
+
167
  cv2.drawContours(visualization_img, [cnt], -1, (255, 0, 0), 2)
168
  curr_num = round(area / avg_area)
169
  num += curr_num
eggcount/pages/home.py CHANGED
@@ -1,9 +1,22 @@
1
  from dash import html, dcc, callback, Input, Output, State
2
  from dash.exceptions import PreventUpdate
3
- from typing import Tuple, Any, Dict
 
4
  from io import BytesIO
5
  from PIL import Image
6
  from pillow_heif import register_heif_opener
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import plotly.express as px
9
  import base64
@@ -16,6 +29,14 @@ dash.register_page(__name__, path = "/")
16
 
17
  UPLOAD_HEIGHT = "25vh"
18
 
 
 
 
 
 
 
 
 
19
  def get_initial_upload_container() -> dbc.Container:
20
  return dcc.Upload(
21
  id = "upload-data",
@@ -88,6 +109,25 @@ layout = dbc.Container(
88
  ],
89
  is_open = False,
90
  id = "upload-modal"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
92
  ],
93
  class_name = "text-center mt-3"
@@ -129,3 +169,158 @@ def on_image_upload(
129
  str(e),
130
  True
131
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from dash import html, dcc, callback, Input, Output, State
2
  from dash.exceptions import PreventUpdate
3
+ from typing import Tuple, Any, Dict, Optional
4
+ from functools import partial
5
  from io import BytesIO
6
  from PIL import Image
7
  from pillow_heif import register_heif_opener
8
+ from eggcount.gradient import (
9
+ component_thesh,
10
+ component_filter_thresh,
11
+ contour_thresh
12
+ )
13
+ from eggcount.ui.ui_utils import (
14
+ get_cc_ui,
15
+ get_cc_filter_ui,
16
+ get_contour_ui,
17
+ display_slider_value,
18
+ get_results_container
19
+ )
20
 
21
  import plotly.express as px
22
  import base64
 
29
 
30
  UPLOAD_HEIGHT = "25vh"
31
 
32
+ COUNT_FUNCS = {
33
+ "Gradient CC": get_cc_ui,
34
+ "Gradient CC w/ filter": get_cc_filter_ui,
35
+ "Contour": get_contour_ui
36
+ }
37
+
38
+ DEFAULT_STRATEGY = "Gradient CC"
39
+
40
  def get_initial_upload_container() -> dbc.Container:
41
  return dcc.Upload(
42
  id = "upload-data",
 
109
  ],
110
  is_open = False,
111
  id = "upload-modal"
112
+ ),
113
+ html.H4("Select Counting Strategy", className = "text-start mt-3"),
114
+ dcc.Dropdown(
115
+ options = [name for name in COUNT_FUNCS],
116
+ value = DEFAULT_STRATEGY,
117
+ id = "strat-picker",
118
+ className = "my-2 w-100"
119
+ ),
120
+ dbc.Container(
121
+ id = "count-ui-container",
122
+ className = "mt-1 mx-0 px-0"
123
+ ),
124
+ dcc.Loading(
125
+ children = dbc.Container(
126
+ id = "count-res-container",
127
+ className = "mt-4 mx-0 px-0"
128
+ ),
129
+ type = "default",
130
+ color = "black"
131
  )
132
  ],
133
  class_name = "text-center mt-3"
 
169
  str(e),
170
  True
171
  )
172
+
173
+ @callback(
174
+ Output("count-ui-container", "children"),
175
+ Input("strat-picker", "value")
176
+ )
177
+ def on_select_strat(
178
+ curr_strat: str
179
+ ) -> Optional[dbc.Container]:
180
+ if curr_strat not in COUNT_FUNCS:
181
+ return None
182
+
183
+ ui_fun = COUNT_FUNCS[curr_strat]
184
+ return ui_fun()
185
+
186
+ @callback(
187
+ Output("count-res-container", "children", allow_duplicate = True),
188
+ Input("count-cc", "n_clicks"),
189
+ State("select-cc-color-thresh", "value"),
190
+ State("select-cc-avg-area", "value"),
191
+ State("select-cc-max-eggs", "value"),
192
+ State("img-data-store", "data"),
193
+ allow_duplicate = True,
194
+ prevent_initial_call = True
195
+ )
196
+ def on_count_cc(
197
+ n_clicks: int,
198
+ color_thresh: int,
199
+ avg_area: int,
200
+ max_eggs: Optional[int],
201
+ image_store: Dict,
202
+ ) -> dbc.Container:
203
+ if not n_clicks:
204
+ return None
205
+
206
+ decoded_bytes = base64.b64decode(image_store["img"])
207
+ image_data = BytesIO(decoded_bytes)
208
+ pil_img = Image.open(image_data)
209
+ img = np.array(pil_img)
210
+
211
+ color_thresh = int(color_thresh)
212
+ avg_area = int(avg_area)
213
+
214
+ if max_eggs:
215
+ max_eggs = int(max_eggs)
216
+
217
+ results = component_thesh(
218
+ img,
219
+ color_thresh = color_thresh,
220
+ avg_area = avg_area,
221
+ max_eggs = max_eggs
222
+ )
223
+
224
+ return get_results_container(results)
225
+
226
+ @callback(
227
+ Output("count-res-container", "children", allow_duplicate = True),
228
+ Input("count-cc-filter", "n_clicks"),
229
+ State("select-cc-filter-color-thresh", "value"),
230
+ State("select-cc-filter-avg-area", "value"),
231
+ State("select-cc-filter-max-eggs", "value"),
232
+ State("select-cc-kernel-width", "value"),
233
+ State("select-cc-kernel-height", "value"),
234
+ State("img-data-store", "data"),
235
+ prevent_initial_call = True
236
+ )
237
+ def on_count_cc(
238
+ n_clicks: int,
239
+ color_thresh: int,
240
+ avg_area: int,
241
+ max_eggs: Optional[int],
242
+ kernel_width: int,
243
+ kernel_height: int,
244
+ image_store: Dict,
245
+ ) -> dbc.Container:
246
+ if not n_clicks:
247
+ return None
248
+
249
+ decoded_bytes = base64.b64decode(image_store["img"])
250
+ image_data = BytesIO(decoded_bytes)
251
+ pil_img = Image.open(image_data)
252
+ img = np.array(pil_img)
253
+
254
+ color_thresh = int(color_thresh)
255
+ avg_area = int(avg_area)
256
+ kernel_width = int(kernel_width)
257
+ kernel_height = int(kernel_height)
258
+
259
+ if max_eggs:
260
+ max_eggs = int(max_eggs)
261
+
262
+ results = component_filter_thresh(
263
+ img,
264
+ color_thresh = color_thresh,
265
+ avg_area = avg_area,
266
+ kernal_size = (kernel_width, kernel_height),
267
+ max_eggs = max_eggs
268
+ )
269
+
270
+ return get_results_container(results)
271
+
272
+ @callback(
273
+ Output("count-res-container", "children", allow_duplicate = True),
274
+ Input("count-contour", "n_clicks"),
275
+ State("select-contour-color-thresh", "value"),
276
+ State("select-contour-avg-area", "value"),
277
+ State("select-contour-max-eggs", "value"),
278
+ State("select-contour-width", "value"),
279
+ State("select-contour-height", "value"),
280
+ State("img-data-store", "data"),
281
+ prevent_initial_call = True
282
+ )
283
+ def on_count_contour(
284
+ n_clicks: int,
285
+ color_thresh: int,
286
+ avg_area: int,
287
+ max_eggs: Optional[int],
288
+ kernel_width: int,
289
+ kernel_height: int,
290
+ image_store: Dict,
291
+ ) -> dbc.Container:
292
+ if not n_clicks:
293
+ return None
294
+
295
+ decoded_bytes = base64.b64decode(image_store["img"])
296
+ image_data = BytesIO(decoded_bytes)
297
+ pil_img = Image.open(image_data)
298
+ img = np.array(pil_img)
299
+
300
+ color_thresh = int(color_thresh)
301
+ avg_area = int(avg_area)
302
+ kernel_width = int(kernel_width)
303
+ kernel_height = int(kernel_height)
304
+
305
+ if max_eggs:
306
+ max_eggs = int(max_eggs)
307
+
308
+ results = contour_thresh(
309
+ img,
310
+ color_thresh = color_thresh,
311
+ avg_area = avg_area,
312
+ kernal_size = (kernel_width, kernel_height),
313
+ max_eggs = max_eggs
314
+ )
315
+
316
+ return get_results_container(results)
317
+
318
+ callback(
319
+ Output("display-cc-color-thresh", "children"),
320
+ Input("select-cc-color-thresh", "value")
321
+ )(partial(display_slider_value, "Color Threshold"))
322
+
323
+ callback(
324
+ Output("display-cc-filter-color-thresh", "children"),
325
+ Input("select-cc-filter-color-thresh", "value")
326
+ )(partial(display_slider_value, "Color Threshold"))
eggcount/ui/ui_utils.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import dash_bootstrap_components as dbc
2
 
3
  NAVBAR_MIN_HEIGHT = "4rem"
@@ -5,6 +10,16 @@ NAVBAR_MIN_HEIGHT = "4rem"
5
  def get_navbar() -> dbc.Nav:
6
  return dbc.Nav(
7
  children = [
 
 
 
 
 
 
 
 
 
 
8
  dbc.NavItem(
9
  children = dbc.NavLink(
10
  children = "Home",
@@ -27,8 +42,220 @@ def get_navbar() -> dbc.Nav:
27
  )
28
  )
29
  ],
30
- class_name = "bg-dark d-flex flex-row justify-content-start align-items-center",
31
  style = {
32
  "min-height": NAVBAR_MIN_HEIGHT
33
  }
34
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dash import html, dcc
2
+ from typing import Dict
3
+
4
+ import plotly.express as px
5
+ import numpy as np
6
  import dash_bootstrap_components as dbc
7
 
8
  NAVBAR_MIN_HEIGHT = "4rem"
 
10
  def get_navbar() -> dbc.Nav:
11
  return dbc.Nav(
12
  children = [
13
+ html.A(
14
+ dbc.Row(
15
+ [
16
+ dbc.Col(html.Img(src = "/assets/mosquito-white.png", height = "30px")),
17
+ ],
18
+ align = "center",
19
+ className = "g-0 px-2",
20
+ ),
21
+ style = {"textDecoration": "none"},
22
+ ),
23
  dbc.NavItem(
24
  children = dbc.NavLink(
25
  children = "Home",
 
42
  )
43
  )
44
  ],
45
+ className = "bg-dark d-flex flex-row justify-content-start align-items-center",
46
  style = {
47
  "min-height": NAVBAR_MIN_HEIGHT
48
  }
49
+ )
50
+
51
+ def display_slider_value(name: str, value: int | float) -> str:
52
+ return f"{name}: {value}"
53
+
54
+ def get_cc_ui() -> dbc.Container:
55
+ return dbc.Container(
56
+ children = [
57
+ html.H5(
58
+ children = "Color Threshold (0 - 255)",
59
+ className = "text-start my-3"
60
+ ),
61
+ dcc.Slider(
62
+ min = 0,
63
+ max = 255,
64
+ value = 75,
65
+ id = "select-cc-color-thresh",
66
+ className = "my-1"
67
+ ),
68
+ html.P(
69
+ children = display_slider_value("Color Threshold", 75),
70
+ id = "display-cc-color-thresh"
71
+ ),
72
+ html.H5(
73
+ children = "Average Area",
74
+ className = "text-start my-3"
75
+ ),
76
+ dcc.Input(
77
+ value = 800,
78
+ type = "number",
79
+ id = "select-cc-avg-area",
80
+ className = "w-50"
81
+ ),
82
+ html.H5(
83
+ children = "Max Eggs (optional)",
84
+ className = "text-start my-3 mt-3"
85
+ ),
86
+ dcc.Input(
87
+ type = "number",
88
+ id = "select-cc-max-eggs",
89
+ className = "w-50"
90
+ ),
91
+ dbc.Button(
92
+ children = "Count",
93
+ color = "primary",
94
+ className = "w-25 mt-4",
95
+ id = "count-cc"
96
+ )
97
+ ],
98
+ className = "p-3 m-0 border border-dark d-flex flex-column justify-content-center align-items-left"
99
+ )
100
+
101
+ def get_cc_filter_ui() -> dbc.Container:
102
+ return dbc.Container(
103
+ children = [
104
+ html.H5(
105
+ children = "Color Threshold (0 - 255)",
106
+ className = "text-start my-3"
107
+ ),
108
+ dcc.Slider(
109
+ min = 0,
110
+ max = 255,
111
+ value = 75,
112
+ id = "select-cc-filter-color-thresh",
113
+ className = "my-1"
114
+ ),
115
+ html.P(
116
+ children = display_slider_value("Color Threshold", 75),
117
+ id = "display-cc-filter-color-thresh"
118
+ ),
119
+ html.H5(
120
+ children = "Average Area",
121
+ className = "text-start my-3"
122
+ ),
123
+ dcc.Input(
124
+ value = 800,
125
+ type = "number",
126
+ id = "select-cc-filter-avg-area",
127
+ className = "w-50"
128
+ ),
129
+ html.H5(
130
+ children = "Max Eggs (optional)",
131
+ className = "text-start my-3 mt-3"
132
+ ),
133
+ dcc.Input(
134
+ type = "number",
135
+ id = "select-cc-filter-max-eggs",
136
+ className = "w-50"
137
+ ),
138
+ html.H5(
139
+ children = "Kernel Width",
140
+ className = "text-start my-3"
141
+ ),
142
+ dcc.Input(
143
+ value = 3,
144
+ type = "number",
145
+ id = "select-cc-kernel-width",
146
+ className = "w-50"
147
+ ),
148
+ html.H5(
149
+ children = "Kernel Height",
150
+ className = "text-start my-3"
151
+ ),
152
+ dcc.Input(
153
+ value = 3,
154
+ type = "number",
155
+ id = "select-cc-kernel-height",
156
+ className = "w-50"
157
+ ),
158
+ dbc.Button(
159
+ children = "Count",
160
+ color = "primary",
161
+ className = "w-25 mt-4",
162
+ id = "count-cc-filter"
163
+ )
164
+ ],
165
+ className = "p-3 m-0 border border-dark d-flex flex-column justify-content-center align-items-left"
166
+ )
167
+
168
+ def get_contour_ui() -> dbc.Container:
169
+ return dbc.Container(
170
+ children = [
171
+ html.H5(
172
+ children = "Color Threshold (0 - 255)",
173
+ className = "text-start my-3"
174
+ ),
175
+ dcc.Slider(
176
+ min = 0,
177
+ max = 255,
178
+ value = 75,
179
+ id = "select-contour-color-thresh",
180
+ className = "my-1"
181
+ ),
182
+ html.P(
183
+ children = display_slider_value("Color Threshold", 75),
184
+ id = "display-contour-color-thresh"
185
+ ),
186
+ html.H5(
187
+ children = "Average Area",
188
+ className = "text-start my-3"
189
+ ),
190
+ dcc.Input(
191
+ value = 800,
192
+ type = "number",
193
+ id = "select-contour-avg-area",
194
+ className = "w-50"
195
+ ),
196
+ html.H5(
197
+ children = "Max Eggs (optional)",
198
+ className = "text-start my-3 mt-3"
199
+ ),
200
+ dcc.Input(
201
+ type = "number",
202
+ id = "select-contour-max-eggs",
203
+ className = "w-50"
204
+ ),
205
+ html.H5(
206
+ children = "Kernel Width",
207
+ className = "text-start my-3"
208
+ ),
209
+ dcc.Input(
210
+ value = 3,
211
+ type = "number",
212
+ id = "select-contour-width",
213
+ className = "w-50"
214
+ ),
215
+ html.H5(
216
+ children = "Kernel Height",
217
+ className = "text-start my-3"
218
+ ),
219
+ dcc.Input(
220
+ value = 3,
221
+ type = "number",
222
+ id = "select-contour-height",
223
+ className = "w-50"
224
+ ),
225
+ dbc.Button(
226
+ children = "Count",
227
+ color = "primary",
228
+ className = "w-25 mt-4",
229
+ id = "count-contour"
230
+ )
231
+ ],
232
+ className = "p-3 m-0 border border-dark d-flex flex-column justify-content-center align-items-left"
233
+ )
234
+
235
+ def get_results_container(result: Dict) -> dbc.Container:
236
+ children = list()
237
+
238
+ for stat_name, stat in result["stats"].items():
239
+ children.append(
240
+ html.H5(
241
+ children = f"{stat_name.replace('-', ' ')}: {stat}",
242
+ className = "text-start w-100"
243
+ ),
244
+ )
245
+
246
+ children.append(html.Hr(className = "border border-dark"))
247
+
248
+ for vis_name, vis_pic in result["vis"].items():
249
+ children.append(html.H4(vis_name.replace("-", " ")))
250
+ image_fig = px.imshow(vis_pic)
251
+ children.append(
252
+ dcc.Graph(
253
+ figure = image_fig,
254
+ className = "w-100"
255
+ )
256
+ )
257
+
258
+ return dbc.Container(
259
+ children = children,
260
+ className = "p-3 m-0 border border-dark d-flex flex-column justify-content-center align-items-center"
261
+ )