xf3227 commited on
Commit
2730dad
·
1 Parent(s): d217981
Files changed (1) hide show
  1. app.py +53 -21
app.py CHANGED
@@ -2,12 +2,29 @@ import streamlit as st
2
  import json
3
  import random
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  @st.cache_resource
6
  def load_model():
7
  import adrd
8
- ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
9
- # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_densenet_emb_encoder_2_AUPR.pt'
10
- model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
 
 
 
11
  return model
12
 
13
  @st.cache_resource
@@ -37,28 +54,39 @@ dat_tst = CSVDataset(
37
  if 'input_text' not in st.session_state:
38
  st.session_state.input_text = ""
39
 
40
- # Create a form for user input
41
- with st.form("json_input_form"):
42
- st.write("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"Random NACC Example\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"Predict\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
43
- json_input = st.text_area(
44
- "Please enter JSON-formatted input features:",
45
- value = st.session_state.input_text,
46
- height = 250
47
- )
48
 
49
- # create three columns
50
- left_col, middle_col, right_col = st.columns([3, 4, 1])
51
 
52
- with left_col:
53
- sample_button = st.form_submit_button("Random NACC Case")
 
 
 
 
 
 
54
 
55
- with right_col:
56
- submit_button = st.form_submit_button("Predict")
57
 
 
 
 
 
 
 
58
  if sample_button:
59
  idx = random.randint(0, len(dat_tst) - 1)
60
- example = dat_tst[idx][0]
61
- st.session_state.input_text = json.dumps(example)
62
 
63
  # reset input text after form processing to show updated text in the input box
64
  if 'input_text' in st.session_state:
@@ -69,8 +97,12 @@ elif submit_button:
69
  # Parse the JSON input into a Python dictionary
70
  data_dict = json.loads(json_input)
71
  pred_dict = predict_proba(data_dict)
72
- st.write("Predicted probabilities:")
73
- st.json(pred_dict)
 
74
  except json.JSONDecodeError as e:
75
  # Handle JSON parsing errors
76
  st.error(f"An error occurred: {e}")
 
 
 
 
2
  import json
3
  import random
4
 
5
+ # set page configuration to wide mode
6
+ st.set_page_config(layout="wide")
7
+
8
+ st.markdown("""
9
+ <style>
10
+ .bounding-box {
11
+ border: 2px solid #4CAF50; # Green border
12
+ border-radius: 5px; # Rounded corners
13
+ padding: 10px; # Padding inside the box
14
+ margin: 10px; # Space outside the box
15
+ }
16
+ </style>
17
+ """, unsafe_allow_html=True)
18
+
19
  @st.cache_resource
20
  def load_model():
21
  import adrd
22
+ try:
23
+ ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
24
+ model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
25
+ except:
26
+ ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_densenet_emb_encoder_2_AUPR.pt'
27
+ model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
28
  return model
29
 
30
  @st.cache_resource
 
54
  if 'input_text' not in st.session_state:
55
  st.session_state.input_text = ""
56
 
57
+ # section 1
58
+ st.markdown("#### About ADRD")
59
+ st.markdown("Differential diagnosis of dementia remains a challenge in neurology due to symptom overlap across etiologies, yet it is crucial for formulating early, personalized management strategies. Here, we present an AI model that harnesses a broad array of data, including demographics, individual and family medical history, medication use, neuropsychological assessments, functional evaluations, and multimodal neuroimaging, to identify the etiologies contributing to dementia in individuals.")
60
+
61
+ # section 2
62
+ st.markdown("#### Demo")
63
+ st.markdown("Please enter the input features in the textbox below, formatted as a JSON dictionary. Click the \"**Random NACC Case**\" button to populate the textbox with a randomly selected case from the NACC testing dataset. Use the \"**Predict**\" button to submit your input to the model, which will then provide probability predictions for mental status and all 10 etiologies.")
 
64
 
65
+ # layout
66
+ layout_l, layout_r = st.columns([1, 1])
67
 
68
+ # create a form for user input
69
+ with layout_l:
70
+ with st.form("json_input_form"):
71
+ json_input = st.text_area(
72
+ "Please enter JSON-formatted input features:",
73
+ value = st.session_state.input_text,
74
+ height = 250
75
+ )
76
 
77
+ # create three columns
78
+ left_col, middle_col, right_col = st.columns([3, 4, 1])
79
 
80
+ with left_col:
81
+ sample_button = st.form_submit_button("Random NACC Case")
82
+
83
+ with right_col:
84
+ submit_button = st.form_submit_button("Predict")
85
+
86
  if sample_button:
87
  idx = random.randint(0, len(dat_tst) - 1)
88
+ random_case = dat_tst[idx][0]
89
+ st.session_state.input_text = json.dumps(random_case, indent=2)
90
 
91
  # reset input text after form processing to show updated text in the input box
92
  if 'input_text' in st.session_state:
 
97
  # Parse the JSON input into a Python dictionary
98
  data_dict = json.loads(json_input)
99
  pred_dict = predict_proba(data_dict)
100
+ with layout_r:
101
+ st.write("Predicted probabilities:")
102
+ st.json(pred_dict)
103
  except json.JSONDecodeError as e:
104
  # Handle JSON parsing errors
105
  st.error(f"An error occurred: {e}")
106
+
107
+ # section 3
108
+