update

- app.py +114 -78
- app_backup_1.py +135 -0

app.py
CHANGED
@@ -3,21 +3,20 @@ import json
 import random
 import pandas as pd
 import pickle
+import json

 # set page configuration to wide mode
 st.set_page_config(layout="wide")

-st.markdown("""
-<style>
-.bounding-box {
-    border: 2px solid #4CAF50; # Green border
-    border-radius: 5px; # Rounded corners
-    padding: 10px; # Padding inside the box
-    margin: 10px; # Space outside the box
-}
-</style>
-""", unsafe_allow_html=True)
+# section 1
+st.markdown("#### About")
+st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
+
+# section 2
+st.markdown("#### Demo")
+st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**RANDOM EXAMPLE**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**PREDICT**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")

+# load model
 @st.cache_resource
 def load_model():
     import adrd
@@ -25,26 +24,21 @@ def load_model():
         ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
         model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
     except:
-        ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
-        model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
+        # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
+        # model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
+        return None
     return model

-@st.cache_resource
-def load_nacc_data():
-    from data.dataset_csv import CSVDataset
-    dat = CSVDataset(
-        dat_file = "./data/test.csv",
-        cnf_file = "./data/input_meta_info.csv"
-    )
-    return dat
-
 model = load_model()
-dat_tst = load_nacc_data()

 def predict_proba(data_dict):
     pred_dict = model.predict_proba([data_dict])[1][0]
     return pred_dict

+# load meta data csv
+file_path = './data/input_meta_info.csv'
+input_meta_info = pd.read_csv(file_path)
+
 # load NACC testing data
 from data.dataset_csv import CSVDataset
 dat_tst = CSVDataset(
@@ -52,42 +46,98 @@ dat_tst = CSVDataset(
     cnf_file = "./data/input_meta_info.csv"
 )

-# initialize session state for the text input if it's not already set
-if 'input_text' not in st.session_state:
-    st.session_state.input_text = ""
-
-# section 1
-st.markdown("#### About")
-st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
-
-# section 2
-st.markdown("#### Demo")
-st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
-
-# layout
-layout_l, layout_r = st.columns([1, 1])
-
-# create a form for user input
-with layout_l:
-    with st.form("json_input_form"):
-        json_input = st.text_area(
-            "Please enter JSON-formatted input features:",
-            value = st.session_state.input_text,
-            height = 300
+def get_random_example():
+    idx = random.randint(0, len(dat_tst) - 1)
+    random_case = dat_tst[idx][0]
+    return random_case
+
+# Get random example features if the button is clicked
+if 'random_example' not in st.session_state:
+    st.session_state.random_example = None
+
+if st.button("RANDOM EXAMPLE"):
+    st.session_state.random_example = get_random_example()
+    st.rerun()
+
+random_example = st.session_state.random_example
+
+def create_input(i):
+    row = input_meta_info.iloc[i]
+    name = row['Name']
+    description = row['Description']
+
+    # dirty work, inspect keys and values
+    values = row['Values']
+    values = values.replace('\'', '\"')
+    values = values.replace('\"0\": nan, ', '')
+    values = json.loads(values)
+
+    for k, v in list(values.items()):
+        if v == 'Unknown':
+            values.pop(k)
+        elif k in ('9', '99', '999'):
+            values.pop(k)
+
+    # get default value from random example if available
+    default_value = random_example[name] if random_example and name in random_example else None
+    if type(default_value) is float:
+        default_value = int(default_value)
+
+    # Determine the type of widget based on values
+    if 'range' in values:
+        if ' - ' in values['range']:
+            min_value, max_value = map(float, values['range'].split(' - '))
+            min_value, max_value = int(min_value), int(max_value)
+
+            if default_value is not None:
+                if default_value > max_value or default_value < min_value:
+                    default_value = None
+
+            st.number_input(description, key=name, min_value=min_value, max_value=max_value, value=default_value, placeholder=values['range'])
+        else:
+            min_value = int(values['range'].replace('>= ', ''))
+            if default_value is not None:
+                if default_value < min_value or default_value == 8888:
+                    default_value = None
+
+            st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
+    else:
+        values = {int(k): v for k, v in values.items()}
+        reverse_mapping = {v: k for k, v in values.items()}
+        if default_value in values:
+            default_index = list(values.keys()).index(default_value)
+        else:
+            default_index = None
+
+        st.selectbox(
+            description,
+            options = values.keys(),
+            key = name,
+            index = default_index,
+            format_func=lambda x: values[x]
         )

-        # create three columns
-        left_col, middle_col, right_col = st.columns([3, 4, 1])
+# create form
+with st.form("dynamic_form"):
+    # random_example_button = st.form_submit_button("RANDOM EXAMPLE")

-        with left_col:
-            sample_button = st.form_submit_button("Random case")
+    cols = st.columns(3)
+    with cols[0]:
+        for i in range(0, len(input_meta_info), 3):
+            create_input(i)
+    with cols[1]:
+        for i in range(1, len(input_meta_info), 3):
+            create_input(i)
+    with cols[2]:
+        for i in range(2, len(input_meta_info), 3):
+            create_input(i)

-        with right_col:
-            submit_button = st.form_submit_button("Predict")
-
+    predict_button = st.form_submit_button("PREDICT")
+
+# load mapping
 with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
     nacc_mapping = pickle.load(file)
-
+
 def convert_dictionary(original_dict, mappings):
     transformed_dict = {}

@@ -104,31 +154,17 @@ def convert_dictionary(original_dict, mappings):
             transformed_dict[new_key] = transformed_value

     return transformed_dict
-
-if sample_button:
-    idx = random.randint(0, len(dat_tst) - 1)
-    random_case = dat_tst[idx][0]
-    st.session_state.input_text = json.dumps(random_case, indent=2)

-    # reset input text after form processing to show updated text in the input box
-    if 'input_text' in st.session_state:
-        st.experimental_rerun()
+if predict_button:
+    # get form input
+    names = input_meta_info['Name'].tolist()
+    data_dict = {}
+    for name in names:
+        data_dict[name] = st.session_state[name]

-elif submit_button:
-    try:
-        # Parse the JSON input into a Python dictionary
-        data_dict = json.loads(json_input)
-        data_dict = convert_dictionary(data_dict, nacc_mapping)
-        # print(data_dict)
-        pred_dict = predict_proba(data_dict)
-        with layout_r:
-            st.write("Predicted probabilities:")
-            st.code(json.dumps(pred_dict, indent=2))
-    except json.JSONDecodeError as e:
-        # Handle JSON parsing errors
-        st.error(f"An error occurred: {e}")
-
-# section 3
-st.markdown("#### Feature Table")
-df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
-st.table(df_input_meta_info)
+    # convert
+    data_dict = convert_dictionary(data_dict, nacc_mapping)
+    pred_dict = predict_proba(data_dict)
+
+    st.write("Predicted probabilities:")
+    st.code(json.dumps(pred_dict, indent=2))
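Note: the new create_input() shown above builds one widget per row of input_meta_info.csv by parsing the row's 'Values' cell, which is stored as a stringified dict. A minimal standalone sketch of that parsing step, using a hypothetical cell value (not taken from the actual CSV) and condensing the two pop branches into one condition:

import json

values = "{'0': 'No', '1': 'Yes', '9': 'Unknown'}"  # hypothetical 'Values' cell
values = values.replace('\'', '\"')                 # single -> double quotes so json.loads accepts it
values = values.replace('\"0\": nan, ', '')         # drop a literal nan entry if present
values = json.loads(values)

# drop 'Unknown' and the 9/99/999 missing-data codes, as create_input does
for k, v in list(values.items()):
    if v == 'Unknown' or k in ('9', '99', '999'):
        values.pop(k)

print(values)  # -> {'0': 'No', '1': 'Yes'}
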
app_backup_1.py
ADDED
@@ -0,0 +1,135 @@
+import streamlit as st
+import json
+import random
+import pandas as pd
+import pickle
+
+# set page configuration to wide mode
+st.set_page_config(layout="wide")
+
+st.markdown("""
+<style>
+.bounding-box {
+    border: 2px solid #4CAF50; # Green border
+    border-radius: 5px; # Rounded corners
+    padding: 10px; # Padding inside the box
+    margin: 10px; # Space outside the box
+}
+</style>
+""", unsafe_allow_html=True)
+
+@st.cache_resource
+def load_model():
+    import adrd
+    try:
+        ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
+        model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
+    except:
+        # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
+        # model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
+        return None
+    return model
+
+@st.cache_resource
+def load_nacc_data():
+    from data.dataset_csv import CSVDataset
+    dat = CSVDataset(
+        dat_file = "./data/test.csv",
+        cnf_file = "./data/input_meta_info.csv"
+    )
+    return dat
+
+model = load_model()
+dat_tst = load_nacc_data()
+
+def predict_proba(data_dict):
+    pred_dict = model.predict_proba([data_dict])[1][0]
+    return pred_dict
+
+# load NACC testing data
+from data.dataset_csv import CSVDataset
+dat_tst = CSVDataset(
+    dat_file = "./data/test.csv",
+    cnf_file = "./data/input_meta_info.csv"
+)
+
+# initialize session state for the text input if it's not already set
+if 'input_text' not in st.session_state:
+    st.session_state.input_text = ""
+
+# section 1
+st.markdown("#### About")
+st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
+
+# section 2
+st.markdown("#### Demo")
+st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
+
+# layout
+layout_l, layout_r = st.columns([1, 1])
+
+# create a form for user input
+with layout_l:
+    with st.form("json_input_form"):
+        json_input = st.text_area(
+            "Please enter JSON-formatted input features:",
+            value = st.session_state.input_text,
+            height = 300
+        )
+
+        # create three columns
+        left_col, middle_col, right_col = st.columns([3, 4, 1])
+
+        with left_col:
+            sample_button = st.form_submit_button("Random case")
+
+        with right_col:
+            submit_button = st.form_submit_button("Predict")
+
+with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
+    nacc_mapping = pickle.load(file)
+
+def convert_dictionary(original_dict, mappings):
+    transformed_dict = {}
+
+    for key, value in original_dict.items():
+        if key in mappings:
+            new_key, transform_map = mappings[key]
+
+            # If the value needs to be transformed
+            if value in transform_map:
+                transformed_value = transform_map[value]
+            else:
+                transformed_value = value # Keep the original value if no transformation is needed
+
+            transformed_dict[new_key] = transformed_value
+
+    return transformed_dict
+
+if sample_button:
+    idx = random.randint(0, len(dat_tst) - 1)
+    random_case = dat_tst[idx][0]
+    st.session_state.input_text = json.dumps(random_case, indent=2)
+
+    # reset input text after form processing to show updated text in the input box
+    if 'input_text' in st.session_state:
+        st.experimental_rerun()
+
+elif submit_button:
+    try:
+        # Parse the JSON input into a Python dictionary
+        data_dict = json.loads(json_input)
+        data_dict = convert_dictionary(data_dict, nacc_mapping)
+        # print(data_dict)
+        pred_dict = predict_proba(data_dict)
+        with layout_r:
+            st.write("Predicted probabilities:")
+            st.code(json.dumps(pred_dict, indent=2))
+    except json.JSONDecodeError as e:
+        # Handle JSON parsing errors
+        st.error(f"An error occurred: {e}")
+
+# section 3
+st.markdown("#### Feature Table")
+df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
+st.table(df_input_meta_info)
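For reference, convert_dictionary() in both versions expects each entry of nacc_variable_mappings.pkl to map an original NACC field name to a (new_key, transform_map) pair. A minimal sketch with hypothetical placeholder entries (the real keys and codes live in the pickle file) and a condensed version of the function shown above:

nacc_mapping = {
    'SEX': ('his_SEX', {1: 'male', 2: 'female'}),  # hypothetical entry
    'NACCAGE': ('his_NACCAGE', {}),                # hypothetical entry, no value transform
}

def convert_dictionary(original_dict, mappings):
    transformed_dict = {}
    for key, value in original_dict.items():
        if key in mappings:
            new_key, transform_map = mappings[key]
            # transform the value if a mapping exists, otherwise keep it as-is
            transformed_dict[new_key] = transform_map.get(value, value)
    return transformed_dict

print(convert_dictionary({'SEX': 1, 'NACCAGE': 76}, nacc_mapping))
# -> {'his_SEX': 'male', 'his_NACCAGE': 76}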