xf3227 commited on
Commit
fc4b558
·
1 Parent(s): 5c24fb9
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -25,11 +25,12 @@ def load_model():
25
  model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
26
  except:
27
  # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
28
- # model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
29
- return None
30
  return model
31
 
32
  model = load_model()
 
33
 
34
  def predict_proba(data_dict):
35
  pred_dict = model.predict_proba([data_dict])[1][0]
@@ -103,7 +104,6 @@ def create_input(i):
103
  st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
104
  else:
105
  values = {int(k): v for k, v in values.items()}
106
- reverse_mapping = {v: k for k, v in values.items()}
107
  if default_value in values:
108
  default_index = list(values.keys()).index(default_value)
109
  else:
@@ -166,5 +166,12 @@ if predict_button:
166
  data_dict = convert_dictionary(data_dict, nacc_mapping)
167
  pred_dict = predict_proba(data_dict)
168
 
 
 
 
 
 
 
 
169
  st.write("Predicted probabilities:")
170
  st.code(json.dumps(pred_dict, indent=2))
 
25
  model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
26
  except:
27
  # ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
28
+ ckpt_path = '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_swinunetr_stripped_MNI.pt'
29
+ model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
30
  return model
31
 
32
  model = load_model()
33
+ print(dir(model))
34
 
35
  def predict_proba(data_dict):
36
  pred_dict = model.predict_proba([data_dict])[1][0]
 
104
  st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
105
  else:
106
  values = {int(k): v for k, v in values.items()}
 
107
  if default_value in values:
108
  default_index = list(values.keys()).index(default_value)
109
  else:
 
166
  data_dict = convert_dictionary(data_dict, nacc_mapping)
167
  pred_dict = predict_proba(data_dict)
168
 
169
+ # change key name
170
+ # key_mappings = {
171
+ # 'NC': 'Normal cognition',
172
+ # 'MCI':
173
+
174
+ # }
175
+
176
  st.write("Predicted probabilities:")
177
  st.code(json.dumps(pred_dict, indent=2))