ok
Browse files
app.py
CHANGED
@@ -2,12 +2,29 @@ import streamlit as st
|
|
2 |
import json
|
3 |
import random
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
@st.cache_resource
|
6 |
def load_model():
|
7 |
import adrd
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
11 |
return model
|
12 |
|
13 |
@st.cache_resource
|
@@ -37,28 +54,39 @@ dat_tst = CSVDataset(
|
|
37 |
if 'input_text' not in st.session_state:
|
38 |
st.session_state.input_text = ""
|
39 |
|
40 |
-
#
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
)
|
48 |
|
49 |
-
|
50 |
-
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
|
56 |
-
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
if sample_button:
|
59 |
idx = random.randint(0, len(dat_tst) - 1)
|
60 |
-
|
61 |
-
st.session_state.input_text = json.dumps(
|
62 |
|
63 |
# reset input text after form processing to show updated text in the input box
|
64 |
if 'input_text' in st.session_state:
|
@@ -69,8 +97,12 @@ elif submit_button:
|
|
69 |
# Parse the JSON input into a Python dictionary
|
70 |
data_dict = json.loads(json_input)
|
71 |
pred_dict = predict_proba(data_dict)
|
72 |
-
|
73 |
-
|
|
|
74 |
except json.JSONDecodeError as e:
|
75 |
# Handle JSON parsing errors
|
76 |
st.error(f"An error occurred: {e}")
|
|
|
|
|
|
|
|
2 |
import json
|
3 |
import random
|
4 |
|
5 |
+
# set page configuration to wide mode
|
6 |
+
st.set_page_config(layout="wide")
|
7 |
+
|
8 |
+
st.markdown("""
|
9 |
+
<style>
|
10 |
+
.bounding-box {
|
11 |
+
border: 2px solid #4CAF50; # Green border
|
12 |
+
border-radius: 5px; # Rounded corners
|
13 |
+
padding: 10px; # Padding inside the box
|
14 |
+
margin: 10px; # Space outside the box
|
15 |
+
}
|
16 |
+
</style>
|
17 |
+
""", unsafe_allow_html=True)
|
18 |
+
|
19 |
@st.cache_resource
|
20 |
def load_model():
|
21 |
import adrd
|
22 |
+
try:
|
23 |
+
ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
|
24 |
+
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
25 |
+
except:
|
26 |
+
ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_densenet_emb_encoder_2_AUPR.pt'
|
27 |
+
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
28 |
return model
|
29 |
|
30 |
@st.cache_resource
|
|
|
54 |
if 'input_text' not in st.session_state:
|
55 |
st.session_state.input_text = ""
|
56 |
|
57 |
+
# section 1
|
58 |
+
st.markdown("#### About ADRD")
|
59 |
+
st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
|
60 |
+
|
61 |
+
# section 2
|
62 |
+
st.markdown("#### Demo")
|
63 |
+
st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random NACC Case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
|
|
|
64 |
|
65 |
+
# layout
|
66 |
+
layout_l, layout_r = st.columns([1, 1])
|
67 |
|
68 |
+
# create a form for user input
|
69 |
+
with layout_l:
|
70 |
+
with st.form("json_input_form"):
|
71 |
+
json_input = st.text_area(
|
72 |
+
"Please enter JSON-formatted input features:",
|
73 |
+
value = st.session_state.input_text,
|
74 |
+
height = 250
|
75 |
+
)
|
76 |
|
77 |
+
# create three columns
|
78 |
+
left_col, middle_col, right_col = st.columns([3, 4, 1])
|
79 |
|
80 |
+
with left_col:
|
81 |
+
sample_button = st.form_submit_button("Random NACC Case")
|
82 |
+
|
83 |
+
with right_col:
|
84 |
+
submit_button = st.form_submit_button("Predict")
|
85 |
+
|
86 |
if sample_button:
|
87 |
idx = random.randint(0, len(dat_tst) - 1)
|
88 |
+
random_case = dat_tst[idx][0]
|
89 |
+
st.session_state.input_text = json.dumps(random_case, indent=2)
|
90 |
|
91 |
# reset input text after form processing to show updated text in the input box
|
92 |
if 'input_text' in st.session_state:
|
|
|
97 |
# Parse the JSON input into a Python dictionary
|
98 |
data_dict = json.loads(json_input)
|
99 |
pred_dict = predict_proba(data_dict)
|
100 |
+
with layout_r:
|
101 |
+
st.write("Predicted probabilities:")
|
102 |
+
st.json(pred_dict)
|
103 |
except json.JSONDecodeError as e:
|
104 |
# Handle JSON parsing errors
|
105 |
st.error(f"An error occurred: {e}")
|
106 |
+
|
107 |
+
# section 3
|
108 |
+
|