File size: 8,144 Bytes
6397272 fdab7a6 ec804d3 8ea5177 30d58e8 5c24fb9 6397272 2730dad 5c24fb9 0932d80 5c24fb9 9284e05 b6c53b2 2730dad 5c24fb9 3e4df7c de23f75 2730dad 30d58e8 2730dad 5c24fb9 fc4b558 de23f75 3e4df7c de23f75 fdab7a6 5c24fb9 ec804d3 5e35c9f 30d58e8 ec804d3 5c24fb9 867e110 9175574 5c24fb9 9175574 5c24fb9 2730dad ec804d3 5c24fb9 9175574 ec804d3 5c24fb9 9175574 5c24fb9 30d58e8 5c24fb9 30d58e8 2730dad 30d58e8 fdab7a6 5c24fb9 ec804d3 5c24fb9 d987fcd e51c85c d987fcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import streamlit as st
import json
import random
import pandas as pd
import pickle
import json
# Use Streamlit's wide layout (full browser width) for the whole page.
st.set_page_config(layout="wide")

# ---- Section 1: background on the model and links ----
st.markdown("#### About")
st.markdown(
    "Differential diagnosis of dementia remains a challenge in neurology due to "
    "symptom overlap across etiologies, yet it is crucial for formulating early, "
    "personalized management strategies. Here, we present an AI model that "
    "harnesses a broad array of data, including demographics, individual and "
    "family medical history, medication use, neuropsychological assessments, "
    "functional evaluations, and multimodal neuroimaging, to identify the "
    "etiologies contributing to dementia in individuals."
)
st.markdown(
    "Links:\n"
    "* Paper: [https://www.nature.com/articles/s41591-024-03118-z](https://www.nature.com/articles/s41591-024-03118-z)\n"
    "* GitHub: [https://github.com/vkola-lab/nmed2024](https://github.com/vkola-lab/nmed2024)\n"
    "* Our lab: [https://vkola-lab.github.io/](https://vkola-lab.github.io/)"
)

# ---- Section 2: what the demo supports and how to use it ----
st.markdown("#### Demo")
st.markdown(
    "This Hugging Face Space is published for demonstration purposes. Users can "
    "input over 300 clinical entries to assess the etiologies contributing to "
    "cognitive impairment. However, due to the computational power limitations of "
    "the Hugging Face free tier, imaging features and Shapley values analysis are "
    "not supported. For the full implementation, please refer to our GitHub "
    "repository."
)
st.markdown(
    "To use the demo:\n"
    "* Provide input features in the form below. Feature missing is allowed.\n"
    '* Click the "**RANDOM EXAMPLE**" button to populate the form with a randomly selected datapoint.\n'
    '* Use the "**PREDICT**" button to submit all input features for assessment, then the predictions will be posted in a table.'
)
# load model
@st.cache_resource
def load_model():
    """Load the ADRD model checkpoint on CPU.

    Wrapped in ``st.cache_resource`` so the model is built once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Candidate checkpoint paths are tried in order; the first one that exists
    as a file is loaded. Errors raised while loading an existing checkpoint
    (e.g. a corrupt file) propagate instead of being silently masked by a
    fallback path.

    Returns:
        The model object produced by ``adrd.model.ADRDModel.from_ckpt``.

    Raises:
        FileNotFoundError: if none of the candidate checkpoint files exist.
    """
    import os

    import adrd

    candidate_paths = (
        # Checkpoint next to the app (relative to the working directory).
        './ckpt_swinunetr_stripped_MNI.pt',
        # Developer-machine fallback retained from the original code.
        '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_swinunetr_stripped_MNI.pt',
    )
    for ckpt_path in candidate_paths:
        if os.path.isfile(ckpt_path):
            return adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
    raise FileNotFoundError(
        'No model checkpoint found; looked in: ' + ', '.join(candidate_paths)
    )


model = load_model()
def predict_proba(data_dict):
    """Run the module-level ``model`` on a single input record.

    ``data_dict`` is wrapped in a one-element batch for
    ``model.predict_proba``; element ``[1]`` of that call's result is taken
    and the entry for the only sample (index 0) is returned. The caller
    iterates the result with ``.items()``, i.e. treats it as a mapping from
    label code to probability.
    """
    batch = [data_dict]
    outputs = model.predict_proba(batch)
    per_sample = outputs[1]
    return per_sample[0]
# Feature metadata table: one row per form field. The form code reads its
# Name, Description, Values and Section columns.
file_path = './data/input_meta_info.csv'
input_meta_info = pd.read_csv(file_path)

# Public NACC test split; get_random_example() samples cases from it.
from data.dataset_csv import CSVDataset

dat_tst = CSVDataset(
    dat_file="./data/test_public.csv",
    cnf_file=file_path,  # same metadata file that was loaded above
)
def get_random_example():
    """Pick one case uniformly at random from the public test set ``dat_tst``.

    Returns element ``[0]`` of the sampled dataset item. The form uses this
    value as a ``{feature name: value}`` lookup when pre-filling widgets.
    """
    sample = dat_tst[random.randrange(len(dat_tst))]
    return sample[0]
# Keep the sampled example in session state so it survives Streamlit reruns.
if 'random_example' not in st.session_state:
    st.session_state['random_example'] = None

st.markdown('---')
cols = st.columns(3)
with cols[1]:
    random_example_button = st.button("RANDOM EXAMPLE", use_container_width=True)
if random_example_button:
    # Store a freshly sampled case, then rerun the script so the form below
    # is rebuilt with that case as its default values.
    st.session_state.random_example = get_random_example()
    st.rerun()

# Default values consumed by create_input(); None until the button is pressed.
random_example = st.session_state.random_example
def _parse_value_options(raw_values):
    """Parse a metadata ``Values`` cell into an ordered dict of options.

    The cell is a dict-like string written with single quotes. It is turned
    into JSON by swapping quote characters and dropping a ``"0": nan, ``
    entry. Options labelled ``'Unknown'`` and the keys ``'9'``, ``'99'`` and
    ``'999'`` (presumably missing-data codes -- TODO confirm) are removed so
    they are never offered in the UI.
    """
    # NOTE(review): the quote swap breaks on labels containing apostrophes,
    # and a trailing "0": nan entry (not followed by ", ") is not stripped.
    text = raw_values.replace('\'', '\"')
    text = text.replace('\"0\": nan, ', '')
    parsed = json.loads(text)
    return {
        k: v
        for k, v in parsed.items()
        if v != 'Unknown' and k not in ('9', '99', '999')
    }


def _default_from_example(name):
    """Return the loaded random example's value for ``name``, or None.

    Floats (including float subclasses such as numpy float64) are truncated
    to int, because the number widgets are built with int bounds. Non-finite
    floats (NaN/inf) cannot be converted by ``int()`` and are treated as
    missing instead of crashing the form.
    """
    import math

    if not random_example or name not in random_example:
        return None
    value = random_example[name]
    if isinstance(value, float):
        if not math.isfinite(value):
            return None
        value = int(value)
    return value


def create_input(df, i):
    """Render the Streamlit input widget described by row ``i`` of ``df``.

    ``df`` is a slice of the feature metadata table. The row's ``Name`` is
    the widget key (so the value is later read from ``st.session_state``),
    ``Description`` is the label, and ``Values`` decides the widget type:

    * a ``range`` entry of the form ``"lo - hi"`` -> bounded ``number_input``;
    * any other ``range`` entry (parsed as ``">= lo"``) -> ``number_input``
      with only a lower bound;
    * otherwise -> ``selectbox`` over the integer-coded options.

    If a random example is loaded and holds a usable value for this field,
    that value pre-fills the widget; values outside the allowed range are
    dropped.
    """
    row = df.iloc[i]
    name = row['Name']
    description = row['Description']
    values = _parse_value_options(row['Values'])
    default_value = _default_from_example(name)

    if 'range' in values:
        value_range = values['range']
        if ' - ' in value_range:
            low, high = map(float, value_range.split(' - '))
            low, high = int(low), int(high)
            if default_value is not None and not (low <= default_value <= high):
                default_value = None
            st.number_input(
                description,
                key=name,
                min_value=low,
                max_value=high,
                value=default_value,
                placeholder=value_range,
            )
        else:
            low = int(value_range.replace('>= ', ''))
            # 8888 is also rejected as a default; presumably a "not
            # available" code in the source data -- TODO confirm.
            if default_value is not None and (default_value < low or default_value == 8888):
                default_value = None
            st.number_input(
                description,
                key=name,
                min_value=low,
                value=default_value,
                placeholder=value_range,
            )
    else:
        options = {int(k): v for k, v in values.items()}
        default_index = list(options).index(default_value) if default_value in options else None
        st.selectbox(
            description,
            options=options.keys(),
            key=name,
            index=default_index,
            format_func=lambda x: options[x],
        )
# Build the input form: one titled block per metadata section, with that
# section's fields dealt round-robin into three columns.
with st.form("dynamic_form"):
    sections = input_meta_info['Section'].unique()
    for section in sections:
        with st.container():
            st.markdown(f"##### {section}")
            sub_df = input_meta_info[input_meta_info['Section'] == section]
            cols = st.columns(3)
            for col_idx, col in enumerate(cols):
                with col:
                    # Column k holds rows k, k+3, k+6, ...
                    for i in range(col_idx, len(sub_df), 3):
                        create_input(sub_df, i)
    # separator line above the centered submit button
    st.markdown("---")
    cols = st.columns(3)
    with cols[1]:
        predict_button = st.form_submit_button("PREDICT", use_container_width=True, type='primary')
# Load the form-to-model variable mapping; convert_dictionary unpacks each
# entry as (new_key, value_map).
# NOTE: pickle can execute code on load -- only ship trusted mapping files.
with open('./data/nacc_variable_mappings.pkl', 'rb') as mapping_file:
    nacc_mapping = pickle.load(mapping_file)
def convert_dictionary(original_dict, mappings):
    """Rename keys and recode values of ``original_dict`` using ``mappings``.

    Each entry of ``mappings`` is ``key -> (new_key, value_map)``. For every
    key of ``original_dict`` that has a mapping, the result stores the value
    under ``new_key``, replaced by ``value_map[value]`` when the value is
    present in ``value_map`` and kept unchanged otherwise. Keys without a
    mapping are dropped. If two keys map to the same ``new_key``, the later
    one wins.
    """
    converted = {}
    for form_key, form_value in original_dict.items():
        if form_key not in mappings:
            continue
        target_key, value_map = mappings[form_key]
        converted[target_key] = value_map[form_value] if form_value in value_map else form_value
    return converted
if predict_button:
    # Every widget was created with key=Name, so the submitted form values
    # are read back from session state by feature name.
    names = input_meta_info['Name'].tolist()
    data_dict = {name: st.session_state[name] for name in names}
    # Translate form names/codes into the variable names the model expects.
    data_dict = convert_dictionary(data_dict, nacc_mapping)
    pred_dict = predict_proba(data_dict)
    # Human-readable names for the model's label codes.
    key_mappings = {
        'NC': 'Normal cognition',
        'MCI': 'Mild cognitive impairment',
        'DE': 'Dementia',
        'AD': 'Alzheimer\'s disease',
        'LBD': 'Lewy bodies and Parkinson\'s disease',
        'VD': 'Vascular brain injury or vascular dementia including stroke',
        'PRD': 'Prion disease including Creutzfeldt-Jakob disease',
        'FTD': 'Frontotemporal lobar degeneration',
        'NPH': 'Normal pressure hydrocephalus',
        'SEF': 'Systemic and external factors',
        'PSY': 'Psychiatric diseases',
        'TBI': 'Traumatic brain injury',
        'ODE': 'Other causes which include neoplasms, multiple systems atrophy, essential tremor, Huntington\'s disease, Down syndrome, and seizures'
    }
    # Show an unlisted label under its raw code instead of aborting the
    # whole prediction with a KeyError.
    pred_dict = {key_mappings.get(k, k): f"{v * 100:.2f}%" for k, v in pred_dict.items()}
    df = pd.DataFrame(list(pred_dict.items()), columns=['Label', 'Predicted probability'])
    st.table(df)