Spaces:
Running
Running
Commit
·
48009b1
1
Parent(s):
b88647e
Update predict plot with CIPHER ASAP demo changes
Browse files- pages/predict.py +29 -9
pages/predict.py
CHANGED
@@ -186,33 +186,53 @@ with st.spinner('Computing predictions...'):
|
|
186 |
|
187 |
if show_val:
|
188 |
# selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
|
189 |
-
selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].
|
|
|
190 |
else:
|
191 |
-
selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
# Show filtered nodes
|
194 |
if target_node_type not in ['disease', 'anatomy']:
|
195 |
-
st.dataframe(
|
196 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
197 |
help = "Click to visit external database.",
|
198 |
display_text = display_database)})
|
199 |
else:
|
200 |
-
st.dataframe(
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
# Plot rank vs. score using matplotlib
|
203 |
-
st.markdown("**Rank vs. Score**")
|
204 |
fig, ax = plt.subplots(figsize = (10, 6))
|
205 |
-
ax.plot(display_data['Rank'], display_data['Score'])
|
206 |
ax.set_xlabel('Rank', fontsize = 12)
|
207 |
ax.set_ylabel('Score', fontsize = 12)
|
208 |
ax.set_xlim(1, display_data['Rank'].max())
|
209 |
|
|
|
|
|
|
|
|
|
210 |
# Add vertical line for selected nodes
|
211 |
for i, node in selected_display_data.iterrows():
|
212 |
-
ax.
|
213 |
-
ax.
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
-
# Show plot
|
216 |
st.pyplot(fig)
|
217 |
|
218 |
|
|
|
186 |
|
187 |
if show_val:
|
188 |
# selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
|
189 |
+
selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
|
190 |
+
selected_display_data = selected_display_data.reset_index(drop=True).style.map(style_val, subset=val_relations)
|
191 |
else:
|
192 |
+
selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
|
193 |
+
selected_display_data = selected_display_data.reset_index(drop=True)
|
194 |
+
|
195 |
+
st.markdown(f"Out of {target_nodes.shape[0]} {target_node_type} nodes, the selected nodes rank as follows:")
|
196 |
+
selected_display_data_with_rank = selected_display_data.copy()
|
197 |
+
selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
|
198 |
|
199 |
# Show filtered nodes
|
200 |
if target_node_type not in ['disease', 'anatomy']:
|
201 |
+
st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
|
202 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
203 |
help = "Click to visit external database.",
|
204 |
display_text = display_database)})
|
205 |
else:
|
206 |
+
st.dataframe(selected_display_data_with_rank, use_container_width = True)
|
207 |
+
|
208 |
+
# Show plot
|
209 |
+
st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")
|
210 |
+
|
211 |
+
# Checkbox to show text labels
|
212 |
+
show_labels = st.checkbox("Show Text Labels?", value = False)
|
213 |
|
214 |
# Plot rank vs. score using matplotlib
|
|
|
215 |
fig, ax = plt.subplots(figsize = (10, 6))
|
216 |
+
ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
|
217 |
ax.set_xlabel('Rank', fontsize = 12)
|
218 |
ax.set_ylabel('Score', fontsize = 12)
|
219 |
ax.set_xlim(1, display_data['Rank'].max())
|
220 |
|
221 |
+
# Get color palette
|
222 |
+
# palette = plt.cm.get_cmap('tab10', len(selected_display_data))
|
223 |
+
palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
|
224 |
+
|
225 |
# Add vertical line for selected nodes
|
226 |
for i, node in selected_display_data.iterrows():
|
227 |
+
ax.scatter(node['Rank'], node['Score'], color = palette[i], zorder=3)
|
228 |
+
ax.axvline(node['Rank'], color = palette[i], linestyle = '--', linewidth = 1.5, label = node['Name'], zorder=3)
|
229 |
+
if show_labels:
|
230 |
+
ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = palette[i], zorder=3)
|
231 |
+
|
232 |
+
# Add legend
|
233 |
+
ax.legend(loc = 'upper right', fontsize = 10)
|
234 |
+
ax.grid(alpha = 0.2, zorder=0)
|
235 |
|
|
|
236 |
st.pyplot(fig)
|
237 |
|
238 |
|