Spaces:
Runtime error
Runtime error
from dash import html, dcc, callback, Input, Output, State | |
from dash.exceptions import PreventUpdate | |
from typing import Tuple, Any, Dict, Optional | |
from functools import partial | |
from io import BytesIO | |
from PIL import Image | |
from pillow_heif import register_heif_opener | |
from larvaecount.gradient import ( | |
component_thesh, | |
component_filter_thresh, | |
contour_thresh | |
) | |
from larvaecount.ui.ui_utils import ( | |
get_cc_ui, | |
get_cc_filter_ui, | |
get_contour_ui, | |
display_slider_value, | |
get_results_container | |
) | |
import plotly.express as px | |
import base64 | |
import dash | |
import dash_bootstrap_components as dbc | |
import numpy as np | |
# Register the pillow_heif opener with Pillow before any upload is decoded.
register_heif_opener()

# Register this module as the Dash page served at the root path.
dash.register_page(__name__, path="/")

# CSS height of the initial drag-and-drop upload area.
UPLOAD_HEIGHT = "25vh"

# Strategy name shown in the picker -> function that builds that strategy's UI.
COUNT_FUNCS = {
    "Gradient CC": get_cc_ui,
    "Gradient CC w/ Filter": get_cc_filter_ui,
    "Gradient Contour w/ Filter": get_contour_ui,
}

# Strategy preselected in the dropdown; must be a key of COUNT_FUNCS.
DEFAULT_STRATEGY = "Gradient CC w/ Filter"
def get_initial_upload_container() -> dcc.Upload:
    """Build the drag-and-drop upload area shown before any image is loaded.

    Returns:
        A ``dcc.Upload`` with id ``"upload-data"`` whose child is a centered
        container holding the camera icon and the prompt heading. (The
        annotation previously claimed ``dbc.Container``, which is only the
        inner child, not the returned component.)
    """
    prompt = dbc.Container(
        children=[
            html.Img(
                src="assets/camera.png",
                alt="camera-image",
                className="h-50",
            ),
            html.H2("Drag and Drop or Select Image File"),
        ],
        class_name="w-100 d-flex flex-column justify-content-center align-items-center",
        style={"height": UPLOAD_HEIGHT},
    )
    return dcc.Upload(id="upload-data", children=prompt)
def get_new_upload_container(
    image_b64: str,
    file_name: str
) -> dbc.Container:
    """Build the upload area shown once an image has been loaded.

    Args:
        image_b64: Base64-encoded bytes of the uploaded image file.
        file_name: Text shown in the heading above the image.

    Returns:
        A container holding the heading, a graph of the decoded image and an
        "Upload New Image" button wrapped in a ``dcc.Upload`` that reuses the
        id ``"upload-data"``.
    """
    # base64 text -> raw bytes -> PIL image -> array for plotly.
    pixels = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))

    heading = html.H3(children=file_name, className="p-2 text-start")
    preview = dcc.Graph(figure=px.imshow(pixels, height=750))
    reupload = dbc.Container(
        children=dcc.Upload(
            children=dbc.Button(children="Upload New Image", color="secondary"),
            id="upload-data",
        ),
        class_name="w-100 pb-4 d-flex flex-row justify-content-center align-items-center",
    )
    return dbc.Container(children=[heading, preview, reupload])
# Page layout: upload area, image store, error modal, strategy picker, and the
# two containers that start empty ("count-ui-container", "count-res-container").
layout = dbc.Container(
    children=[
        # Upload area; starts with the drag-and-drop prompt.
        dbc.Container(
            children=dcc.Loading(
                children=get_initial_upload_container(),
                id="image-upload-container",
                color="black",
            ),
            class_name="m-0 p-0 border border-dark",
        ),
        # In-memory store for the uploaded image data.
        dcc.Store(id="img-data-store", storage_type="memory"),
        # Modal used to report a failure while processing an upload.
        dbc.Modal(
            children=[
                dbc.ModalHeader(dbc.ModalTitle("Error Processing Image File")),
                dbc.ModalBody(children="", id="upload-modal-content"),
            ],
            is_open=False,
            id="upload-modal",
        ),
        html.H4("Select Counting Strategy", className="text-start mt-3"),
        dcc.Dropdown(
            options=list(COUNT_FUNCS),
            value=DEFAULT_STRATEGY,
            id="strat-picker",
            className="my-2 w-100",
        ),
        # Placeholder for the selected strategy's controls.
        dbc.Container(id="count-ui-container", className="mt-1 mx-0 px-0"),
        # Placeholder for the count results, wrapped in a loading spinner.
        dcc.Loading(
            children=dbc.Container(
                id="count-res-container",
                className="mt-4 mx-0 px-0",
            ),
            type="default",
            color="black",
        ),
    ],
    class_name="text-center mt-3",
)
def on_image_upload(
    upload_image_data: str,
    upload_image_name: str,
    curr_upload_chidren: Any,
    curr_img_store_data: Dict,
) -> Tuple[dbc.Container, Dict, str, bool]:
    """Handle a newly uploaded image and produce the next UI state.

    Args:
        upload_image_data: Upload contents in ``"<header>,<base64 payload>"``
            form. A falsy value means there is nothing to process.
        upload_image_name: File name shown above the rendered image.
        curr_upload_chidren: Current upload-area children, returned unchanged
            when processing fails.
        curr_img_store_data: Current image store data, returned unchanged when
            processing fails.

    Returns:
        ``(children, store_data, error_message, has_error)``. On success the
        children are the new upload container, the store data is
        ``{"img": <base64 payload>}``, the message is empty and the flag is
        ``False``. On failure the current children/store are kept, the message
        is the exception text and the flag is ``True`` (presumably opening the
        error modal -- the callback wiring is not visible here).

    Raises:
        PreventUpdate: If ``upload_image_data`` is falsy.
    """
    if not upload_image_data:
        raise PreventUpdate

    try:
        # Split on the first comma only: the payload is everything after it,
        # and a comma inside the header must not break the unpacking. A value
        # with no comma still raises ValueError and is reported below.
        _, content_string = upload_image_data.split(",", 1)
        next_children = get_new_upload_container(content_string, upload_image_name)
    except Exception as e:
        # Deliberately broad: any decoding/rendering failure is reported to
        # the user and the page keeps its previous state.
        return (
            curr_upload_chidren,
            curr_img_store_data,
            str(e),
            True,
        )

    return (
        next_children,
        {"img": content_string},
        "",
        False,
    )
def on_select_strat(
    curr_strat: str
) -> Optional[dbc.Container]:
    """Build the controls for the selected counting strategy.

    Args:
        curr_strat: Strategy name; expected to be a key of ``COUNT_FUNCS``.

    Returns:
        The result of calling the UI builder registered for ``curr_strat``,
        or ``None`` when the name is not a known strategy.
    """
    build_ui = COUNT_FUNCS.get(curr_strat)
    if build_ui is None:
        return None
    return build_ui()
def on_count_cc(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    image_store: Dict,
) -> Optional[dbc.Container]:
    """Run ``component_thesh`` on the stored image and render its results.

    NOTE(review): another function named ``on_count_cc`` is defined later in
    this module and rebinds the name, so this definition is not reachable by
    name afterwards -- confirm how it is registered as a callback.

    Args:
        n_clicks: Presumably the click count of the triggering button; a
            falsy value means there is nothing to do.
        color_thresh: Coerced to ``int`` and forwarded as ``color_thresh``.
        avg_area: Coerced to ``int`` and forwarded as ``avg_area``.
        max_eggs: Coerced to ``int`` when truthy, otherwise forwarded as-is.
        image_store: Mapping whose ``"img"`` entry is the base64-encoded
            image; may be ``None`` or empty when nothing has been uploaded.

    Returns:
        ``get_results_container(results)``, or ``None`` when there was no
        click or no stored image.
    """
    image_b64 = image_store.get("img") if image_store else None
    # No click or no uploaded image: nothing to count. Previously a missing
    # image raised TypeError/KeyError on ``image_store["img"]``.
    if not n_clicks or not image_b64:
        return None

    img = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))
    results = component_thesh(
        img,
        color_thresh=int(color_thresh),
        avg_area=int(avg_area),
        max_eggs=int(max_eggs) if max_eggs else max_eggs,
    )
    return get_results_container(results)
def on_count_cc(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    kernel_width: int,
    kernel_height: int,
    image_store: Dict,
) -> Optional[dbc.Container]:
    """Run ``component_filter_thresh`` on the stored image and render its results.

    NOTE(review): this reuses the name ``on_count_cc`` from an earlier
    definition in this module and replaces it. A filter-specific name was
    probably intended -- confirm against the callback registration before
    renaming.

    Args:
        n_clicks: Presumably the click count of the triggering button; a
            falsy value means there is nothing to do.
        color_thresh: Coerced to ``int`` and forwarded as ``color_thresh``.
        avg_area: Coerced to ``int`` and forwarded as ``avg_area``.
        max_eggs: Coerced to ``int`` when truthy, otherwise forwarded as-is.
        kernel_width: Coerced to ``int``; first element of ``kernal_size``.
        kernel_height: Coerced to ``int``; second element of ``kernal_size``.
        image_store: Mapping whose ``"img"`` entry is the base64-encoded
            image; may be ``None`` or empty when nothing has been uploaded.

    Returns:
        ``get_results_container(results)``, or ``None`` when there was no
        click or no stored image.
    """
    image_b64 = image_store.get("img") if image_store else None
    # No click or no uploaded image: nothing to count. Previously a missing
    # image raised TypeError/KeyError on ``image_store["img"]``.
    if not n_clicks or not image_b64:
        return None

    img = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))
    results = component_filter_thresh(
        img,
        color_thresh=int(color_thresh),
        avg_area=int(avg_area),
        # "kernal_size" is the keyword spelling used by the callee.
        kernal_size=(int(kernel_width), int(kernel_height)),
        max_eggs=int(max_eggs) if max_eggs else max_eggs,
    )
    return get_results_container(results)
def on_count_contour(
    n_clicks: int,
    color_thresh: int,
    avg_area: int,
    max_eggs: Optional[int],
    kernel_width: int,
    kernel_height: int,
    image_store: Dict,
) -> Optional[dbc.Container]:
    """Run ``contour_thresh`` on the stored image and render its results.

    Args:
        n_clicks: Presumably the click count of the triggering button; a
            falsy value means there is nothing to do.
        color_thresh: Coerced to ``int`` and forwarded as ``color_thresh``.
        avg_area: Coerced to ``int`` and forwarded as ``avg_area``.
        max_eggs: Coerced to ``int`` when truthy, otherwise forwarded as-is.
        kernel_width: Coerced to ``int``; first element of ``kernal_size``.
        kernel_height: Coerced to ``int``; second element of ``kernal_size``.
        image_store: Mapping whose ``"img"`` entry is the base64-encoded
            image; may be ``None`` or empty when nothing has been uploaded.

    Returns:
        ``get_results_container(results)``, or ``None`` when there was no
        click or no stored image.
    """
    image_b64 = image_store.get("img") if image_store else None
    # No click or no uploaded image: nothing to count. Previously a missing
    # image raised TypeError/KeyError on ``image_store["img"]``.
    if not n_clicks or not image_b64:
        return None

    img = np.array(Image.open(BytesIO(base64.b64decode(image_b64))))
    results = contour_thresh(
        img,
        color_thresh=int(color_thresh),
        avg_area=int(avg_area),
        # "kernal_size" is the keyword spelling used by the callee.
        kernal_size=(int(kernel_width), int(kernel_height)),
        max_eggs=int(max_eggs) if max_eggs else max_eggs,
    )
    return get_results_container(results)
# For each strategy, feed the color-threshold slider's "value" into
# display_slider_value (with the label "Color Threshold" bound as its first
# argument) and write the result to the matching display element's "children".
_COLOR_THRESH_DISPLAY_PAIRS = (
    ("display-cc-color-thresh", "select-cc-color-thresh"),
    ("display-cc-filter-color-thresh", "select-cc-filter-color-thresh"),
    ("display-contour-color-thresh", "select-contour-color-thresh"),
)
for _display_id, _select_id in _COLOR_THRESH_DISPLAY_PAIRS:
    # A fresh partial per registration, matching the original three calls.
    callback(
        Output(_display_id, "children"),
        Input(_select_id, "value"),
    )(partial(display_slider_value, "Color Threshold"))