|
import streamlit as st |
|
import json |
|
import random |
|
import pandas as pd |
|
import pickle |
|
import json |
|
|
|
|
|
st.set_page_config(layout="wide") |
|
|
|
|
|
st.markdown("#### About") |
|
st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.") |
|
st.markdown("Links:\n* Paper: [https://www.nature.com/articles/s41591-024-03118-z](https://www.nature.com/articles/s41591-024-03118-z)\n* GitHub: [https://github.com/vkola-lab/nmed2024](https://github.com/vkola-lab/nmed2024)\n* Our lab: [https://vkola-lab.github.io/](https://vkola-lab.github.io/)") |
|
|
|
|
|
st.markdown("#### Demo") |
|
st.markdown("This Hugging Face Space is published for demonstration purposes. Users can input over 300 clinical entries to assess the etiologies contributing to cognitive impairment. However, due to the computational power limitations of the Hugging Face free tier, imaging features and Shapley values analysis are not supported. For the full implementation, please refer to our GitHub repository.") |
|
st.markdown("To use the demo:\n* Provide input features in the form below. Feature missing is allowed.\n* Click the \"**RANDOM EXAMPLE**\" button to populate the form with a randomly selected datapoint.\n* Use the \"**PREDICT**\" button to submit all input features for assessment, then the predictions will be posted in a table.") |
|
|
|
|
|
@st.cache_resource |
|
def load_model(): |
|
import adrd |
|
try: |
|
ckpt_path = './ckpt_swinunetr_stripped_MNI.pt' |
|
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu') |
|
except: |
|
|
|
ckpt_path = '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_swinunetr_stripped_MNI.pt' |
|
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu') |
|
return model |
|
|
|
model = load_model() |
|
|
|
def predict_proba(data_dict): |
|
pred_dict = model.predict_proba([data_dict])[1][0] |
|
return pred_dict |
|
|
|
|
|
file_path = './data/input_meta_info.csv' |
|
input_meta_info = pd.read_csv(file_path) |
|
|
|
|
|
from data.dataset_csv import CSVDataset |
|
dat_tst = CSVDataset( |
|
dat_file = "./data/test_public.csv", |
|
cnf_file = "./data/input_meta_info.csv" |
|
) |
|
|
|
def get_random_example(): |
|
idx = random.randint(0, len(dat_tst) - 1) |
|
random_case = dat_tst[idx][0] |
|
return random_case |
|
|
|
|
|
if 'random_example' not in st.session_state: |
|
st.session_state.random_example = None |
|
|
|
st.markdown('---') |
|
cols = st.columns(3) |
|
with cols[1]: |
|
random_example_button = st.button("RANDOM EXAMPLE", use_container_width=True) |
|
if random_example_button: |
|
st.session_state.random_example = get_random_example() |
|
st.rerun() |
|
|
|
random_example = st.session_state.random_example |
|
|
|
def create_input(df, i): |
|
row = df.iloc[i] |
|
name = row['Name'] |
|
description = row['Description'] |
|
|
|
|
|
values = row['Values'] |
|
values = values.replace('\'', '\"') |
|
values = values.replace('\"0\": nan, ', '') |
|
values = json.loads(values) |
|
|
|
for k, v in list(values.items()): |
|
if v == 'Unknown': |
|
values.pop(k) |
|
elif k in ('9', '99', '999'): |
|
values.pop(k) |
|
|
|
|
|
default_value = random_example[name] if random_example and name in random_example else None |
|
if type(default_value) is float: |
|
default_value = int(default_value) |
|
|
|
|
|
if 'range' in values: |
|
if ' - ' in values['range']: |
|
min_value, max_value = map(float, values['range'].split(' - ')) |
|
min_value, max_value = int(min_value), int(max_value) |
|
|
|
if default_value is not None: |
|
if default_value > max_value or default_value < min_value: |
|
default_value = None |
|
|
|
st.number_input(description, key=name, min_value=min_value, max_value=max_value, value=default_value, placeholder=values['range']) |
|
else: |
|
min_value = int(values['range'].replace('>= ', '')) |
|
if default_value is not None: |
|
if default_value < min_value or default_value == 8888: |
|
default_value = None |
|
|
|
st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range']) |
|
else: |
|
values = {int(k): v for k, v in values.items()} |
|
if default_value in values: |
|
default_index = list(values.keys()).index(default_value) |
|
else: |
|
default_index = None |
|
|
|
st.selectbox( |
|
description, |
|
options = values.keys(), |
|
key = name, |
|
index = default_index, |
|
format_func=lambda x: values[x] |
|
) |
|
|
|
|
|
with st.form("dynamic_form"): |
|
sections = input_meta_info['Section'].unique() |
|
for section in sections: |
|
|
|
with st.container(): |
|
st.markdown(f"##### {section}") |
|
sub_df = input_meta_info[input_meta_info['Section'] == section] |
|
|
|
cols = st.columns(3) |
|
with cols[0]: |
|
for i in range(0, len(sub_df), 3): |
|
create_input(sub_df, i) |
|
with cols[1]: |
|
for i in range(1, len(sub_df), 3): |
|
create_input(sub_df, i) |
|
with cols[2]: |
|
for i in range(2, len(sub_df), 3): |
|
create_input(sub_df, i) |
|
|
|
|
|
st.markdown("---") |
|
|
|
cols = st.columns(3) |
|
with cols[1]: |
|
predict_button = st.form_submit_button("PREDICT", use_container_width=True, type='primary') |
|
|
|
|
|
with open('./data/nacc_variable_mappings.pkl', 'rb') as file: |
|
nacc_mapping = pickle.load(file) |
|
|
|
def convert_dictionary(original_dict, mappings): |
|
transformed_dict = {} |
|
|
|
for key, value in original_dict.items(): |
|
if key in mappings: |
|
new_key, transform_map = mappings[key] |
|
|
|
|
|
if value in transform_map: |
|
transformed_value = transform_map[value] |
|
else: |
|
transformed_value = value |
|
|
|
transformed_dict[new_key] = transformed_value |
|
|
|
return transformed_dict |
|
|
|
if predict_button: |
|
|
|
names = input_meta_info['Name'].tolist() |
|
data_dict = {} |
|
for name in names: |
|
data_dict[name] = st.session_state[name] |
|
|
|
|
|
data_dict = convert_dictionary(data_dict, nacc_mapping) |
|
pred_dict = predict_proba(data_dict) |
|
|
|
|
|
key_mappings = { |
|
'NC': 'Normal cognition', |
|
'MCI': 'Mild cognitive impairment', |
|
'DE': 'Dementia', |
|
'AD': 'Alzheimer\'s disease', |
|
'LBD': 'Lewy bodies and Parkinson\'s disease', |
|
'VD': 'Vascular brain injury or vascular dementia including stroke', |
|
'PRD': 'Prion disease including Creutzfeldt-Jakob disease', |
|
'FTD': 'Frontotemporal lobar degeneration', |
|
'NPH': 'Normal pressure hydrocephalus', |
|
'SEF': 'Systemic and external factors', |
|
'PSY': 'Psychiatric diseases', |
|
'TBI': 'Traumatic brain injury', |
|
'ODE': 'Other causes which include neoplasms, multiple systems atrophy, essential tremor, Huntington\'s disease, Down syndrome, and seizures' |
|
} |
|
pred_dict = {key_mappings[k]: f"{v * 100:.2f}%" for k, v in pred_dict.items()} |
|
|
|
df = pd.DataFrame(list(pred_dict.items()), columns=['Label', 'Predicted probability']) |
|
st.table(df) |