update
Browse files
app.py
CHANGED
@@ -25,11 +25,12 @@ def load_model():
|
|
25 |
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
26 |
except:
|
27 |
# ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
|
28 |
-
|
29 |
-
|
30 |
return model
|
31 |
|
32 |
model = load_model()
|
|
|
33 |
|
34 |
def predict_proba(data_dict):
|
35 |
pred_dict = model.predict_proba([data_dict])[1][0]
|
@@ -103,7 +104,6 @@ def create_input(i):
|
|
103 |
st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
|
104 |
else:
|
105 |
values = {int(k): v for k, v in values.items()}
|
106 |
-
reverse_mapping = {v: k for k, v in values.items()}
|
107 |
if default_value in values:
|
108 |
default_index = list(values.keys()).index(default_value)
|
109 |
else:
|
@@ -166,5 +166,12 @@ if predict_button:
|
|
166 |
data_dict = convert_dictionary(data_dict, nacc_mapping)
|
167 |
pred_dict = predict_proba(data_dict)
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
st.write("Predicted probabilities:")
|
170 |
st.code(json.dumps(pred_dict, indent=2))
|
|
|
25 |
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
26 |
except:
|
27 |
# ckpt_path = '../adrd_tool_copied_from_sahana/dev/ckpt/ckpt_swinunetr_stripped_MNI.pt'
|
28 |
+
ckpt_path = '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_swinunetr_stripped_MNI.pt'
|
29 |
+
model = adrd.model.ADRDModel.from_ckpt(ckpt_path, device='cpu')
|
30 |
return model
|
31 |
|
32 |
model = load_model()
|
33 |
+
print(dir(model))
|
34 |
|
35 |
def predict_proba(data_dict):
|
36 |
pred_dict = model.predict_proba([data_dict])[1][0]
|
|
|
104 |
st.number_input(description, key=name, min_value=min_value, value=default_value, placeholder=values['range'])
|
105 |
else:
|
106 |
values = {int(k): v for k, v in values.items()}
|
|
|
107 |
if default_value in values:
|
108 |
default_index = list(values.keys()).index(default_value)
|
109 |
else:
|
|
|
166 |
data_dict = convert_dictionary(data_dict, nacc_mapping)
|
167 |
pred_dict = predict_proba(data_dict)
|
168 |
|
169 |
+
# change key name
|
170 |
+
# key_mappings = {
|
171 |
+
# 'NC': 'Normal cognition',
|
172 |
+
# 'MCI':
|
173 |
+
|
174 |
+
# }
|
175 |
+
|
176 |
st.write("Predicted probabilities:")
|
177 |
st.code(json.dumps(pred_dict, indent=2))
|