nmed2024 / app.py
xf3227's picture
ok
2730dad
raw
history blame
3.85 kB
import streamlit as st
import json
import random
# set page configuration to wide mode so the input and output columns sit side by side
st.set_page_config(layout="wide")
# Inject a reusable `.bounding-box` CSS class.
# NOTE: CSS only understands /* ... */ comments. A `#` inside a declaration
# list is an invalid token, and the CSS parser discards everything up to the
# next `;` or `}`. That silently dropped border-radius, padding and margin.
st.markdown("""
<style>
.bounding-box {
    border: 2px solid #4CAF50; /* Green border */
    border-radius: 5px; /* Rounded corners */
    padding: 10px; /* Padding inside the box */
    margin: 10px; /* Space outside the box */
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_model():
    """Load the ADRD checkpoint on CPU, cached as a Streamlit resource.

    The checkpoint next to the app is tried first. If loading it fails, a
    development-tree copy is tried instead.

    Returns:
        The model object returned by ``adrd.model.ADRDModel.from_ckpt``.

    Raises:
        Whatever ``from_ckpt`` raises for the fallback path when both loads
        fail. The first failure is kept as the exception's ``__context__``.
    """
    import adrd

    primary_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
    fallback_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_densenet_emb_encoder_2_AUPR.pt'
    try:
        model = adrd.model.ADRDModel.from_ckpt(primary_path, device='cpu')
    except Exception:
        # `except Exception` rather than a bare `except:`. A bare except also
        # swallows KeyboardInterrupt/SystemExit and Streamlit's
        # BaseException-based rerun/stop signals, which must not trigger the fallback.
        model = adrd.model.ADRDModel.from_ckpt(fallback_path, device='cpu')
    return model
@st.cache_resource
def load_nacc_data():
    """Build the NACC test-split dataset, cached as a Streamlit resource.

    Returns:
        A ``CSVDataset`` over ``./data/nacc_test_with_np_cli.csv``, configured
        by ``./data/default_conf_new.toml``.
    """
    from data.dataset_csv import CSVDataset

    return CSVDataset(
        dat_file="./data/nacc_test_with_np_cli.csv",
        cnf_file="./data/default_conf_new.toml",
    )
# Both loaders are @st.cache_resource-decorated. After the first script run
# they return the cached objects instead of rebuilding them.
model = load_model()  # ADRD model, CPU
dat_tst = load_nacc_data()  # NACC test cases used by "Random NACC Case"
def predict_proba(data_dict):
    """Run the module-level ``model`` on a single case.

    Args:
        data_dict: feature dict for one case (parsed from the user's JSON).

    Returns:
        Entry ``[0]`` of element ``[1]`` of ``model.predict_proba`` called on a
        one-element batch. The UI displays this as "Predicted probabilities".
        NOTE(review): presumably element ``[0]`` of the model output holds
        another view (e.g. raw scores). Confirm against the adrd API.
    """
    single_case_batch = [data_dict]
    batch_probas = model.predict_proba(single_case_batch)[1]
    return batch_probas[0]
# load NACC testing data
# Use the @st.cache_resource-backed loader. Constructing CSVDataset directly
# here re-read and re-parsed the CSV on every Streamlit rerun (i.e. every
# widget interaction). It also replaced the cached instance with an uncached one.
from data.dataset_csv import CSVDataset  # noqa: F401  kept in case later module code refers to it
dat_tst = load_nacc_data()
# Seed the text area's contents with an empty string the first time a
# browser session runs the script; later reruns keep whatever was stored.
if "input_text" not in st.session_state:
    st.session_state["input_text"] = ""
# section 1: background on the model
st.markdown("#### About ADRD")
st.markdown(
    "Differential diagnosis of dementia remains a challenge in neurology due to "
    "symptom overlap across etiologies, yet it is crucial for formulating early, "
    "personalized management strategies. Here, we present an AI model that "
    "harnesses a broad array of data, including demographics, individual and "
    "family medical history, medication use, neuropsychological assessments, "
    "functional evaluations, and multimodal neuroimaging, to identify the "
    "etiologies contributing to dementia in individuals."
)
# section 2: usage instructions for the interactive demo
st.markdown("#### Demo")
st.markdown(
    'Please enter the input features in the textbox below, formatted as a JSON '
    'dictionary. Click the "**Random NACC Case**" button to populate the textbox '
    'with a randomly selected case from the NACC testing dataset. Use the '
    '"**Predict**" button to submit your input to the model, which will then '
    'provide probability predictions for mental status and all 10 etiologies.'
)
# layout: input form on the left, predictions on the right
layout_l, layout_r = st.columns([1, 1])

# create a form for user input
with layout_l:
    with st.form("json_input_form"):
        json_input = st.text_area(
            "Please enter JSON-formatted input features:",
            value=st.session_state.input_text,
            height=250,
        )
        # create three columns; the empty middle one pushes "Predict" to the right
        left_col, middle_col, right_col = st.columns([3, 4, 1])
        with left_col:
            sample_button = st.form_submit_button("Random NACC Case")
        with right_col:
            submit_button = st.form_submit_button("Predict")

if sample_button:
    idx = random.randint(0, len(dat_tst) - 1)
    random_case = dat_tst[idx][0]
    st.session_state.input_text = json.dumps(random_case, indent=2)
    # Rerun the script so the text area is rebuilt with the sampled case as
    # its value. (The old `if 'input_text' in st.session_state` guard was
    # always true here, since the key was just assigned.)
    # st.rerun superseded the deprecated st.experimental_rerun; fall back to
    # the old name only on Streamlit versions that predate st.rerun.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()
elif submit_button:
    try:
        # Parse the JSON input into a Python object
        data_dict = json.loads(json_input)
    except json.JSONDecodeError as e:
        # Handle JSON parsing errors
        st.error(f"An error occurred: {e}")
    else:
        # Valid JSON that is not an object (list, number, string, ...) cannot
        # be a feature dictionary. Report it instead of crashing in the model call.
        if not isinstance(data_dict, dict):
            st.error(
                "An error occurred: input must be a JSON object mapping "
                f"feature names to values, got {type(data_dict).__name__}"
            )
        else:
            pred_dict = predict_proba(data_dict)
            with layout_r:
                st.write("Predicted probabilities:")
                st.json(pred_dict)
# section 3