File size: 8,144 Bytes
6397272
fdab7a6
ec804d3
8ea5177
30d58e8
5c24fb9
6397272
2730dad
 
 
5c24fb9
 
 
0932d80
5c24fb9
 
 
9284e05
b6c53b2
2730dad
5c24fb9
3e4df7c
de23f75
 
2730dad
30d58e8
2730dad
 
5c24fb9
fc4b558
 
de23f75
 
3e4df7c
 
de23f75
 
 
fdab7a6
5c24fb9
 
 
 
ec804d3
 
 
5e35c9f
30d58e8
ec804d3
 
5c24fb9
 
 
 
 
 
 
 
 
867e110
9175574
 
 
 
5c24fb9
 
 
 
 
9175574
 
5c24fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2730dad
ec804d3
5c24fb9
 
9175574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec804d3
5c24fb9
 
9175574
5c24fb9
 
30d58e8
 
5c24fb9
30d58e8
 
 
 
 
 
2730dad
30d58e8
 
 
 
 
 
 
 
 
fdab7a6
5c24fb9
 
 
 
 
 
ec804d3
5c24fb9
 
 
 
d987fcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e51c85c
d987fcd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import streamlit as st
import json
import random
import pandas as pd
import pickle
import json

# set page configuration to wide mode
# Keep this ahead of every other st.* call: Streamlit documents set_page_config
# as needing to be the first Streamlit command executed in the script.
st.set_page_config(layout="wide")

# section 1
# "About" section: static project summary followed by external links.
st.markdown("#### About")
st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
st.markdown("Links:\n* Paper: [https://www.nature.com/articles/s41591-024-03118-z](https://www.nature.com/articles/s41591-024-03118-z)\n* GitHub: [https://github.com/vkola-lab/nmed2024](https://github.com/vkola-lab/nmed2024)\n* Our lab: [https://vkola-lab.github.io/](https://vkola-lab.github.io/)")

# section 2
# "Demo" section: limitations of this hosted demo and how to use the form below.
st.markdown("#### Demo")
st.markdown("This Hugging Face Space is published for demonstration purposes. Users can input over 300 clinical entries to assess the etiologies contributing to cognitive impairment. However, due to the computational power limitations of the Hugging Face free tier, imaging features and Shapley values analysis are not supported. For the full implementation, please refer to our GitHub repository.")
st.markdown("To use the demo:\n* Provide input features in the form below. Feature missing is allowed.\n* Click the \"**RANDOM EXAMPLE**\" button to populate the form with a randomly selected datapoint.\n* Use the \"**PREDICT**\" button to submit all input features for assessment, then the predictions will be posted in a table.")

# load model
@st.cache_resource
def load_model():
    """Load the ADRD model checkpoint onto the CPU, cached across Streamlit reruns.

    The checkpoint bundled next to the app is tried first. If loading it fails
    for any ordinary error, a developer-machine copy is tried instead.

    Returns:
        The object returned by ``adrd.model.ADRDModel.from_ckpt``.

    Raises:
        Exception: whatever ``from_ckpt`` raises for the fallback path; the
            error from the bundled path is attached as its ``__context__``.
    """
    import adrd

    bundled_ckpt = './ckpt_swinunetr_stripped_MNI.pt'
    # Developer-machine location used when the bundled copy is unavailable.
    fallback_ckpt = '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_swinunetr_stripped_MNI.pt'

    try:
        return adrd.model.ADRDModel.from_ckpt(bundled_ckpt, device='cpu')
    except Exception:
        # Narrower than a bare `except:` so Ctrl-C / SystemExit still propagate.
        return adrd.model.ADRDModel.from_ckpt(fallback_ckpt, device='cpu')

model = load_model()

def predict_proba(data_dict):
    """Score one record with the module-level ``model``.

    The record is wrapped in a single-element batch. Element ``[1][0]`` of
    ``model.predict_proba`` is returned, which the predict handler treats as a
    mapping from label to probability.
    NOTE(review): confirm the meaning of index ``[1]`` against the adrd API.
    """
    batch = [data_dict]
    outputs = model.predict_proba(batch)
    return outputs[1][0]

# load meta data csv
# One row per form input; the Name, Description, Section and Values columns
# drive the form built further down.
file_path = './data/input_meta_info.csv'
input_meta_info = pd.read_csv(file_path)

# load NACC testing data
# Public test split that get_random_example() samples from to pre-fill the form.
from data.dataset_csv import CSVDataset
dat_tst = CSVDataset(dat_file="./data/test_public.csv", cnf_file=file_path)

def get_random_example(dataset=None):
    """Return the feature dict of one uniformly chosen record.

    Args:
        dataset: sized, indexable collection whose items are sequences holding
            the feature dict at position 0. Defaults to the module-level
            ``dat_tst`` test split.

    Returns:
        ``dataset[idx][0]`` for a uniformly random ``idx``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if dataset is None:
        dataset = dat_tst
    size = len(dataset)
    if size == 0:
        raise ValueError("cannot draw a random example from an empty dataset")
    # randrange(size) draws the same value as the former randint(0, size - 1).
    return dataset[random.randrange(size)][0]

# Get random example features if the button is clicked
# Initialise the slot once per browser session so the read at the bottom of
# this block always finds a value.
if 'random_example' not in st.session_state:
    st.session_state.random_example = None

st.markdown('---')
cols = st.columns(3)
# Centre the button by placing it in the middle of three columns.
with cols[1]:
    random_example_button = st.button("RANDOM EXAMPLE", use_container_width=True)
if random_example_button:
    # Store the sampled record, then rerun the script so the form below is
    # rebuilt with the new defaults.
    st.session_state.random_example = get_random_example()
    st.rerun()

# Snapshot read by create_input() to pre-fill widget defaults (None until an
# example has been drawn).
random_example = st.session_state.random_example

# Option codes never offered in a select box (dropped alongside options labelled
# 'Unknown'). NOTE(review): presumably the NACC "unknown / not available" codes.
_DROPPED_OPTION_CODES = ('9', '99', '999')


def _parse_values(raw):
    """Decode a ``Values`` metadata cell into a dict of widget options.

    The cell text is converted to JSON in two steps: single quotes are swapped
    for double quotes, and a ``"0": nan, `` entry is deleted. Options labelled
    ``'Unknown'`` or coded in ``_DROPPED_OPTION_CODES`` are then removed.
    NOTE(review): the quote swap would break on any label that contains an
    apostrophe; ``ast.literal_eval`` would be sturdier if the CSV allows it.
    """
    text = raw.replace('\'', '\"').replace('\"0\": nan, ', '')
    parsed = json.loads(text)
    return {
        code: label
        for code, label in parsed.items()
        if label != 'Unknown' and code not in _DROPPED_OPTION_CODES
    }


def _coerce_default(value):
    """Normalise a random-example value for use as a widget default.

    Floats are truncated to ``int`` because the widgets below use integer
    bounds and integer option codes. ``isinstance`` also catches float
    subclasses such as ``numpy.float64``. NaN and infinite floats, which
    ``int()`` cannot convert, become ``None`` (no default).
    """
    import math

    if isinstance(value, float):
        return int(value) if math.isfinite(value) else None
    return value


def create_input(df, i):
    """Render the Streamlit widget for feature row ``i`` of ``df``.

    The widget type depends on the row's ``Values`` entry. A ``range`` entry
    produces an ``st.number_input``. Anything else produces an
    ``st.selectbox`` over the coded options.

    The widget is keyed by the row's ``Name``, so its value can be read back
    from ``st.session_state``. When the module-level ``random_example`` holds
    a value for this feature, that value becomes the default, provided it is
    valid for the widget.

    Args:
        df: DataFrame with ``Name``, ``Description`` and ``Values`` columns.
        i: positional row index into ``df`` (used with ``iloc``).
    """
    row = df.iloc[i]
    name = row['Name']
    description = row['Description']
    values = _parse_values(row['Values'])

    # Pre-fill from the currently loaded random example, if it has this feature.
    has_example = random_example and name in random_example
    default_value = _coerce_default(random_example[name] if has_example else None)

    if 'range' in values:
        spec = values['range']
        if ' - ' in spec:
            # Bounded numeric input; spec looks like "<min> - <max>".
            low, high = (int(float(bound)) for bound in spec.split(' - '))
            if default_value is not None and not low <= default_value <= high:
                default_value = None
            st.number_input(description, key=name, min_value=low, max_value=high,
                            value=default_value, placeholder=spec)
        else:
            # Lower-bounded numeric input; spec looks like ">= <min>".
            low = int(spec.replace('>= ', ''))
            # NOTE(review): 8888 is rejected as a default, presumably a
            # "not applicable" sentinel in the source data; confirm.
            if default_value is not None and (default_value < low or default_value == 8888):
                default_value = None
            st.number_input(description, key=name, min_value=low,
                            value=default_value, placeholder=spec)
        return

    # Categorical input: option codes are stored as strings in the metadata.
    options = {int(code): label for code, label in values.items()}
    codes = list(options)
    default_index = codes.index(default_value) if default_value in options else None
    st.selectbox(
        description,
        options=codes,
        key=name,
        index=default_index,
        format_func=options.__getitem__,
    )

# create form
# One block of inputs per section. Rows are dealt across three columns, so
# row r of a section lands in column r % 3.
with st.form("dynamic_form"):
    sections = input_meta_info['Section'].unique()
    for section in sections:
        with st.container():
            st.markdown(f"##### {section}")
            sub_df = input_meta_info[input_meta_info['Section'] == section]

            cols = st.columns(3)
            for offset, column in enumerate(cols):
                with column:
                    for i in range(offset, len(sub_df), 3):
                        create_input(sub_df, i)

        # separator between sections
        st.markdown("---")

    # Centre the submit button in the middle of three columns.
    cols = st.columns(3)
    with cols[1]:
        predict_button = st.form_submit_button("PREDICT", use_container_width=True, type='primary')

# load mapping
# Maps form field names to (model feature name, value-translation map) pairs;
# consumed by convert_dictionary() below.
# NOTE(review): pickle.load can execute arbitrary code. It is acceptable only if
# this file ships with the app; never point it at user-supplied data.
with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
    nacc_mapping = pickle.load(file)

def convert_dictionary(original_dict, mappings):
    """Rename keys and recode values of ``original_dict`` via ``mappings``.

    Args:
        original_dict: input record keyed by source field name.
        mappings: ``{source_key: (target_key, value_map)}``. Keys of
            ``original_dict`` not present here are dropped.

    Returns:
        A new dict keyed by ``target_key``. Each value is ``value_map[value]``
        when the value appears in ``value_map``, otherwise the value itself.
        If several source keys share a target key, the last one wins.
    """
    converted = {}
    for source_key in original_dict:
        if source_key not in mappings:
            continue
        target_key, value_map = mappings[source_key]
        raw = original_dict[source_key]
        converted[target_key] = value_map[raw] if raw in value_map else raw
    return converted

if predict_button:
    # Gather every form value by its widget key (the feature Name).
    names = input_meta_info['Name'].tolist()
    data_dict = {name: st.session_state[name] for name in names}

    # Translate form field names and codes into the model's expected inputs.
    data_dict = convert_dictionary(data_dict, nacc_mapping)
    pred_dict = predict_proba(data_dict)

    # change key name and value representations
    key_mappings = {
        'NC': 'Normal cognition',
        'MCI': 'Mild cognitive impairment',
        'DE': 'Dementia',
        'AD': 'Alzheimer\'s disease',
        'LBD': 'Lewy bodies and Parkinson\'s disease',
        'VD': 'Vascular brain injury or vascular dementia including stroke',
        'PRD': 'Prion disease including Creutzfeldt-Jakob disease',
        'FTD': 'Frontotemporal lobar degeneration',
        'NPH': 'Normal pressure hydrocephalus',
        'SEF': 'Systemic and external factors',
        'PSY': 'Psychiatric diseases',
        'TBI': 'Traumatic brain injury',
        'ODE': 'Other causes which include neoplasms, multiple systems atrophy, essential tremor, Huntington\'s disease, Down syndrome, and seizures'
    }
    # Fall back to the raw label for any output not listed above instead of
    # crashing the page with a KeyError.
    pred_dict = {key_mappings.get(k, k): f"{v * 100:.2f}%" for k, v in pred_dict.items()}

    df = pd.DataFrame(list(pred_dict.items()), columns=['Label', 'Predicted probability'])
    st.table(df)