xf3227 commited on
Commit
5c24fb9
·
1 Parent(s): 7231707
Files changed (2) hide show
  1. app.py +114 -78
  2. app_backup_1.py +135 -0
app.py CHANGED
@@ -3,21 +3,20 @@ import json
3
  import random
4
  import pandas as pd
5
  import pickle
 
6
 
7
  # set page configuration to wide mode
8
  st.set_page_config(layout="wide")
9
 
10
- st.markdown("""
11
- <style>
12
- .bounding-box {
13
- border: 2px solid #4CAF50; # Green border
14
- border-radius: 5px; # Rounded corners
15
- padding: 10px; # Padding inside the box
16
- margin: 10px; # Space outside the box
17
- }
18
- </style>
19
- """, unsafe_allow_html=True)
20
 
 
21
  @st.cache_resource
22
  def load_model():
23
  import adrd
@@ -25,26 +24,21 @@ def load_model():
25
  ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
26
  model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
27
  except:
28
- ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
29
- model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
 
30
  return model
31
 
32
- @st.cache_resource
33
- def load_nacc_data():
34
- from data.dataset_csv import CSVDataset
35
- dat = CSVDataset(
36
- dat_file = "./data/test.csv",
37
- cnf_file = "./data/input_meta_info.csv"
38
- )
39
- return dat
40
-
41
  model = load_model()
42
- dat_tst = load_nacc_data()
43
 
44
  def predict_proba(data_dict):
45
  pred_dict = model.predict_proba([data_dict])[1][0]
46
  return pred_dict
47
 
 
 
 
 
48
  # load NACC testing data
49
  from data.dataset_csv import CSVDataset
50
  dat_tst = CSVDataset(
@@ -52,42 +46,98 @@ dat_tst = CSVDataset(
52
  cnf_file = "./data/input_meta_info.csv"
53
  )
54
 
55
- # initialize session state for the text input if it's not already set
56
- if 'input_text' not in st.session_state:
57
- st.session_state.input_text = ""
58
-
59
- # section 1
60
- st.markdown("#### About")
61
- st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
62
-
63
- # section 2
64
- st.markdown("#### Demo")
65
- st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
66
-
67
- # layout
68
- layout_l, layout_r = st.columns([1, 1])
69
-
70
- # create a form for user input
71
- with layout_l:
72
- with st.form("json_input_form"):
73
- json_input = st.text_area(
74
- "Please enter JSON-formatted input features:",
75
- value = st.session_state.input_text,
76
- height = 300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  )
78
 
79
- # create three columns
80
- left_col, middle_col, right_col = st.columns([3, 4, 1])
 
81
 
82
- with left_col:
83
- sample_button = st.form_submit_button("Random case")
 
 
 
 
 
 
 
 
84
 
85
- with right_col:
86
- submit_button = st.form_submit_button("Predict")
87
-
88
  with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
89
  nacc_mapping = pickle.load(file)
90
-
91
  def convert_dictionary(original_dict, mappings):
92
  transformed_dict = {}
93
 
@@ -104,31 +154,17 @@ def convert_dictionary(original_dict, mappings):
104
  transformed_dict[new_key] = transformed_value
105
 
106
  return transformed_dict
107
-
108
- if sample_button:
109
- idx = random.randint(0, len(dat_tst) - 1)
110
- random_case = dat_tst[idx][0]
111
- st.session_state.input_text = json.dumps(random_case, indent=2)
112
 
113
- # reset input text after form processing to show updated text in the input box
114
- if 'input_text' in st.session_state:
115
- st.experimental_rerun()
 
 
 
116
 
117
- elif submit_button:
118
- try:
119
- # Parse the JSON input into a Python dictionary
120
- data_dict = json.loads(json_input)
121
- data_dict = convert_dictionary(data_dict, nacc_mapping)
122
- # print(data_dict)
123
- pred_dict = predict_proba(data_dict)
124
- with layout_r:
125
- st.write("Predicted probabilities:")
126
- st.code(json.dumps(pred_dict, indent=2))
127
- except json.JSONDecodeError as e:
128
- # Handle JSON parsing errors
129
- st.error(f"An error occurred: {e}")
130
-
131
- # section 3
132
- st.markdown("#### Feature Table")
133
- df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
134
- st.table(df_input_meta_info)
 
3
  import random
4
  import pandas as pd
5
  import pickle
6
+ import json
7
 
8
  # set page configuration to wide mode
9
  st.set_page_config(layout="wide")
10
 
11
+ # section 1
12
+ st.markdown("#### About")
13
+ st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
14
+
15
+ # section 2
16
+ st.markdown("#### Demo")
17
+ st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**RANDOM EXAMPLE**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**PREDICT**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
 
 
 
18
 
19
+ # load model
20
  @st.cache_resource
21
  def load_model():
22
  import adrd
 
24
  ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
25
  model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
26
  except:
27
+ # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
28
+ # model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
29
+ return None
30
  return model
31
 
 
 
 
 
 
 
 
 
 
32
  model = load_model()
 
33
 
34
  def predict_proba(data_dict):
35
  pred_dict = model.predict_proba([data_dict])[1][0]
36
  return pred_dict
37
 
38
+ # load meta data csv
39
+ file_path = './data/input_meta_info.csv'
40
+ input_meta_info = pd.read_csv(file_path)
41
+
42
  # load NACC testing data
43
  from data.dataset_csv import CSVDataset
44
  dat_tst = CSVDataset(
 
46
  cnf_file = "./data/input_meta_info.csv"
47
  )
48
 
49
+ def get_random_example():
50
+ idx = random.randint(0, len(dat_tst) - 1)
51
+ random_case = dat_tst[idx][0]
52
+ return random_case
53
+
54
+ # Get random example features if the button is clicked
55
+ if 'random_example' not in st.session_state:
56
+ st.session_state.random_example = None
57
+
58
+ if st.button("RANDOM EXAMPLE"):
59
+ st.session_state.random_example = get_random_example()
60
+ st.rerun()
61
+
62
+ random_example = st.session_state.random_example
63
+
64
+ def create_input(i):
65
+ row = input_meta_info.iloc[i]
66
+ name = row['Name']
67
+ description = row['Description']
68
+
69
+ # dirty work, inspect keys and values
70
+ values = row['Values']
71
+ values = values.replace('\'', '\"')
72
+ values = values.replace('\"0\": nan, ', '')
73
+ values = json.loads(values)
74
+
75
+ for k, v in list(values.items()):
76
+ if v == 'Unknown':
77
+ values.pop(k)
78
+ elif k in ('9', '99', '999'):
79
+ values.pop(k)
80
+
81
+ # get default value from random example if available
82
+ default_value = random_example[name] if random_example and name in random_example else None
83
+ if type(default_value) is float:
84
+ default_value = int(default_value)
85
+
86
+ # Determine the type of widget based on values
87
+ if 'range' in values:
88
+ if ' - ' in values['range']:
89
+ min_value, max_value = map(float, values['range'].split(' - '))
90
+ min_value, max_value = int(min_value), int(max_value)
91
+
92
+ if default_value is not None:
93
+ if default_value > max_value or default_value < min_value:
94
+ default_value = None
95
+
96
+ st.number_input(description, key=name, min_value=min_value, max_value=max_value, value=default_value, placeholder=values['range'])
97
+ else:
98
+ min_value = int(values['range'].replace('>= ', ''))
99
+ if default_value is not None:
100
+ if default_value < min_value or default_value == 8888:
101
+ default_value = None
102
+
103
+ st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
104
+ else:
105
+ values = {int(k): v for k, v in values.items()}
106
+ reverse_mapping = {v: k for k, v in values.items()}
107
+ if default_value in values:
108
+ default_index = list(values.keys()).index(default_value)
109
+ else:
110
+ default_index = None
111
+
112
+ st.selectbox(
113
+ description,
114
+ options = values.keys(),
115
+ key = name,
116
+ index = default_index,
117
+ format_func=lambda x: values[x]
118
  )
119
 
120
+ # create form
121
+ with st.form("dynamic_form"):
122
+ # random_example_button = st.form_submit_button("RANDOM EXAMPLE")
123
 
124
+ cols = st.columns(3)
125
+ with cols[0]:
126
+ for i in range(0, len(input_meta_info), 3):
127
+ create_input(i)
128
+ with cols[1]:
129
+ for i in range(1, len(input_meta_info), 3):
130
+ create_input(i)
131
+ with cols[2]:
132
+ for i in range(2, len(input_meta_info), 3):
133
+ create_input(i)
134
 
135
+ predict_button = st.form_submit_button("PREDICT")
136
+
137
+ # load mapping
138
  with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
139
  nacc_mapping = pickle.load(file)
140
+
141
  def convert_dictionary(original_dict, mappings):
142
  transformed_dict = {}
143
 
 
154
  transformed_dict[new_key] = transformed_value
155
 
156
  return transformed_dict
 
 
 
 
 
157
 
158
+ if predict_button:
159
+ # get form input
160
+ names = input_meta_info['Name'].tolist()
161
+ data_dict = {}
162
+ for name in names:
163
+ data_dict[name] = st.session_state[name]
164
 
165
+ # convert
166
+ data_dict = convert_dictionary(data_dict, nacc_mapping)
167
+ pred_dict = predict_proba(data_dict)
168
+
169
+ st.write("Predicted probabilities:")
170
+ st.code(json.dumps(pred_dict, indent=2))
 
 
 
 
 
 
 
 
 
 
 
 
app_backup_1.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import random
4
+ import pandas as pd
5
+ import pickle
6
+
7
+ # set page configuration to wide mode
8
+ st.set_page_config(layout="wide")
9
+
10
+ st.markdown("""
11
+ <style>
12
+ .bounding-box {
13
+ border: 2px solid #4CAF50; # Green border
14
+ border-radius: 5px; # Rounded corners
15
+ padding: 10px; # Padding inside the box
16
+ margin: 10px; # Space outside the box
17
+ }
18
+ </style>
19
+ """, unsafe_allow_html=True)
20
+
21
+ @st.cache_resource
22
+ def load_model():
23
+ import adrd
24
+ try:
25
+ ckpt_path = './ckpt_swinunetr_stripped_MNI.pt'
26
+ model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
27
+ except:
28
+ # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
29
+ # model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
30
+ return None
31
+ return model
32
+
33
+ @st.cache_resource
34
+ def load_nacc_data():
35
+ from data.dataset_csv import CSVDataset
36
+ dat = CSVDataset(
37
+ dat_file = "./data/test.csv",
38
+ cnf_file = "./data/input_meta_info.csv"
39
+ )
40
+ return dat
41
+
42
+ model = load_model()
43
+ dat_tst = load_nacc_data()
44
+
45
+ def predict_proba(data_dict):
46
+ pred_dict = model.predict_proba([data_dict])[1][0]
47
+ return pred_dict
48
+
49
+ # load NACC testing data
50
+ from data.dataset_csv import CSVDataset
51
+ dat_tst = CSVDataset(
52
+ dat_file = "./data/test.csv",
53
+ cnf_file = "./data/input_meta_info.csv"
54
+ )
55
+
56
+ # initialize session state for the text input if it's not already set
57
+ if 'input_text' not in st.session_state:
58
+ st.session_state.input_text = ""
59
+
60
+ # section 1
61
+ st.markdown("#### About")
62
+ st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
63
+
64
+ # section 2
65
+ st.markdown("#### Demo")
66
+ st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
67
+
68
+ # layout
69
+ layout_l, layout_r = st.columns([1, 1])
70
+
71
+ # create a form for user input
72
+ with layout_l:
73
+ with st.form("json_input_form"):
74
+ json_input = st.text_area(
75
+ "Please enter JSON-formatted input features:",
76
+ value = st.session_state.input_text,
77
+ height = 300
78
+ )
79
+
80
+ # create three columns
81
+ left_col, middle_col, right_col = st.columns([3, 4, 1])
82
+
83
+ with left_col:
84
+ sample_button = st.form_submit_button("Random case")
85
+
86
+ with right_col:
87
+ submit_button = st.form_submit_button("Predict")
88
+
89
+ with open('./data/nacc_variable_mappings.pkl', 'rb') as file:
90
+ nacc_mapping = pickle.load(file)
91
+
92
+ def convert_dictionary(original_dict, mappings):
93
+ transformed_dict = {}
94
+
95
+ for key, value in original_dict.items():
96
+ if key in mappings:
97
+ new_key, transform_map = mappings[key]
98
+
99
+ # If the value needs to be transformed
100
+ if value in transform_map:
101
+ transformed_value = transform_map[value]
102
+ else:
103
+ transformed_value = value # Keep the original value if no transformation is needed
104
+
105
+ transformed_dict[new_key] = transformed_value
106
+
107
+ return transformed_dict
108
+
109
+ if sample_button:
110
+ idx = random.randint(0, len(dat_tst) - 1)
111
+ random_case = dat_tst[idx][0]
112
+ st.session_state.input_text = json.dumps(random_case, indent=2)
113
+
114
+ # reset input text after form processing to show updated text in the input box
115
+ if 'input_text' in st.session_state:
116
+ st.experimental_rerun()
117
+
118
+ elif submit_button:
119
+ try:
120
+ # Parse the JSON input into a Python dictionary
121
+ data_dict = json.loads(json_input)
122
+ data_dict = convert_dictionary(data_dict, nacc_mapping)
123
+ # print(data_dict)
124
+ pred_dict = predict_proba(data_dict)
125
+ with layout_r:
126
+ st.write("Predicted probabilities:")
127
+ st.code(json.dumps(pred_dict, indent=2))
128
+ except json.JSONDecodeError as e:
129
+ # Handle JSON parsing errors
130
+ st.error(f"An error occurred: {e}")
131
+
132
+ # section 3
133
+ st.markdown("#### Feature Table")
134
+ df_input_meta_info = pd.read_csv('./data/input_meta_info.csv')
135
+ st.table(df_input_meta_info)