aksell commited on
Commit
b7cb94d
1 Parent(s): 1b13d7f

Add table of top_n attention pairs

Browse files
Files changed (1) hide show
  1. hexviz/app.py +15 -0
hexviz/app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import py3Dmol
2
  import stmol
3
  import streamlit as st
@@ -115,6 +116,20 @@ showmol(xyzview, height=500, width=800)
115
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  """
119
  More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
120
  """
 
1
+ import pandas as pd
2
  import py3Dmol
3
  import stmol
4
  import streamlit as st
 
116
  st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
117
 
118
 
119
+ chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
120
+ data = []
121
+ for att_weight, _ , _ , chain, first, second in top_n:
122
+ res1 = chain_dict[chain][first]
123
+ res2 = chain_dict[chain][second]
124
+ el = (att_weight, f"{res1.resname:3}{res1.id[1]:0>3} - {res2.resname:3}{res2.id[1]:0>3} ({chain})")
125
+ data.append(el)
126
+ # st.write(f"Attention weight: {att_weight:.2f} | Residue pair: {structure.get_chain_id(chain)[first].get_resname()}-{structure.get_chain(chain)[first].full_id[3]}{chain.id}<-->{chain[second].get_resname()}")
127
+
128
+ df = pd.DataFrame(data, columns=['Avg attention', 'Residue pair'])
129
+ f"Top {n_pairs} attention pairs:"
130
+ st.table(df)
131
+
132
+
133
  """
134
  More models will be added soon. The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
135
  """