xf3227 commited on
Commit
de23f75
·
1 Parent(s): b7da9fe
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -1,9 +1,17 @@
1
  import streamlit as st
2
  import json
3
 
4
- import adrd
5
- ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
6
- mdl = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
 
 
 
 
 
 
 
 
7
 
8
  # Create a form for user input
9
  with st.form("json_input_form"):
@@ -16,7 +24,7 @@ if submit_button:
16
  try:
17
  # Parse the JSON input into a Python dictionary
18
  data_dict = json.loads(json_input)
19
- pred_dict = mdl.predict_proba([data_dict])[1][0]
20
  st.write("Predicted probabilities:")
21
  st.json(pred_dict)
22
  except json.JSONDecodeError as e:
 
1
  import streamlit as st
2
  import json
3
 
4
+ @st.cache(allow_output_mutation=True)
5
+ def load_model():
6
+ import adrd
7
+ ckpt_path = './ckpt_densenet_emb_encoder_2_AUPR.pt'
8
+ model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
9
+ return model
10
+
11
+ def predict_proba(data_dict):
12
+ model = load_model()
13
+ pred_dict = model.predict_proba([data_dict])[1][0]
14
+ return pred_dict
15
 
16
  # Create a form for user input
17
  with st.form("json_input_form"):
 
24
  try:
25
  # Parse the JSON input into a Python dictionary
26
  data_dict = json.loads(json_input)
27
+ pred_dict = predict_proba(data_dict)
28
  st.write("Predicted probabilities:")
29
  st.json(pred_dict)
30
  except json.JSONDecodeError as e: