ayushnoori commited on
Commit
48009b1
·
1 Parent(s): b88647e

Update predict plot with CIPHER ASAP demo changes

Browse files
Files changed (1) hide show
  1. pages/predict.py +29 -9
pages/predict.py CHANGED
@@ -186,33 +186,53 @@ with st.spinner('Computing predictions...'):
186
 
187
  if show_val:
188
  # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
189
- selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].style.map(style_val, subset=val_relations)
 
190
  else:
191
- selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
 
 
 
 
 
192
 
193
  # Show filtered nodes
194
  if target_node_type not in ['disease', 'anatomy']:
195
- st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
196
  column_config={"Database": st.column_config.LinkColumn(width = "small",
197
  help = "Click to visit external database.",
198
  display_text = display_database)})
199
  else:
200
- st.dataframe(selected_display_data, use_container_width = True)
 
 
 
 
 
 
201
 
202
  # Plot rank vs. score using matplotlib
203
- st.markdown("**Rank vs. Score**")
204
  fig, ax = plt.subplots(figsize = (10, 6))
205
- ax.plot(display_data['Rank'], display_data['Score'])
206
  ax.set_xlabel('Rank', fontsize = 12)
207
  ax.set_ylabel('Score', fontsize = 12)
208
  ax.set_xlim(1, display_data['Rank'].max())
209
 
 
 
 
 
210
  # Add vertical line for selected nodes
211
  for i, node in selected_display_data.iterrows():
212
- ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
213
- ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
 
 
 
 
 
 
214
 
215
- # Show plot
216
  st.pyplot(fig)
217
 
218
 
 
186
 
187
  if show_val:
188
  # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
189
+ selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
190
+ selected_display_data = selected_display_data.reset_index(drop=True).style.map(style_val, subset=val_relations)
191
  else:
192
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
193
+ selected_display_data = selected_display_data.reset_index(drop=True)
194
+
195
+ st.markdown(f"Out of {target_nodes.shape[0]} {target_node_type} nodes, the selected nodes rank as follows:")
196
+ selected_display_data_with_rank = selected_display_data.copy()
197
+ selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
198
 
199
  # Show filtered nodes
200
  if target_node_type not in ['disease', 'anatomy']:
201
+ st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
202
  column_config={"Database": st.column_config.LinkColumn(width = "small",
203
  help = "Click to visit external database.",
204
  display_text = display_database)})
205
  else:
206
+ st.dataframe(selected_display_data_with_rank, use_container_width = True)
207
+
208
+ # Show plot
209
+ st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")
210
+
211
+ # Checkbox to show text labels
212
+ show_labels = st.checkbox("Show Text Labels?", value = False)
213
 
214
  # Plot rank vs. score using matplotlib
 
215
  fig, ax = plt.subplots(figsize = (10, 6))
216
+ ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
217
  ax.set_xlabel('Rank', fontsize = 12)
218
  ax.set_ylabel('Score', fontsize = 12)
219
  ax.set_xlim(1, display_data['Rank'].max())
220
 
221
+ # Get color palette
222
+ # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
223
+ palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
224
+
225
  # Add vertical line for selected nodes
226
  for i, node in selected_display_data.iterrows():
227
+ ax.scatter(node['Rank'], node['Score'], color = palette[i], zorder=3)
228
+ ax.axvline(node['Rank'], color = palette[i], linestyle = '--', linewidth = 1.5, label = node['Name'], zorder=3)
229
+ if show_labels:
230
+ ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = palette[i], zorder=3)
231
+
232
+ # Add legend
233
+ ax.legend(loc = 'upper right', fontsize = 10)
234
+ ax.grid(alpha = 0.2, zorder=0)
235
 
 
236
  st.pyplot(fig)
237
 
238