aksell committed on
Commit
58b87e3
1 Parent(s): 3de389d

Allow selecting individual chains

Browse files
Files changed (3) hide show
  1. hexviz/app.py +27 -4
  2. hexviz/attention.py +25 -17
  3. tests/test_attention.py +1 -1
hexviz/app.py CHANGED
@@ -3,7 +3,7 @@ import stmol
3
  import streamlit as st
4
  from stmol import showmol
5
 
6
- from hexviz.attention import get_attention_pairs, get_structure
7
  from hexviz.models import Model, ModelType
8
 
9
  st.title("Attention Visualization on proteins")
@@ -21,7 +21,20 @@ models = [
21
  selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
22
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
23
 
24
- pdb_id = st.text_input("PDB ID", "1I60")
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  left, right = st.columns(2)
27
  with left:
@@ -35,21 +48,31 @@ with right:
35
  with st.expander("Configure parameters", expanded=False):
36
  min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
37
  try:
38
- structure = get_structure(pdb_id)
39
  ec_class = structure.header["compound"]["1"]["ec"]
40
  except KeyError:
41
  ec_class = None
42
  if ec_class and selected_model.name == ModelType.ZymCTRL:
43
  ec_class = st.text_input("Enzyme classification number fetched from PDB", ec_class)
44
 
45
- attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
46
 
47
  def get_3dview(pdb):
48
  xyzview = py3Dmol.view(query=f"pdb:{pdb}")
49
  xyzview.setStyle({"cartoon": {"color": "spectrum"}})
50
  stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
 
 
 
 
 
 
51
  for att_weight, first, second in attention_pairs:
52
  stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
 
 
 
 
 
53
  return xyzview
54
 
55
 
 
3
  import streamlit as st
4
  from stmol import showmol
5
 
6
+ from hexviz.attention import get_attention_pairs, get_chains, get_structure
7
  from hexviz.models import Model, ModelType
8
 
9
  st.title("Attention Visualization on proteins")
 
21
  selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
22
  selected_model = next((model for model in models if model.name.value == selected_model_name), None)
23
 
24
+ st.sidebar.title("Settings")
25
+
26
+ pdb_id = st.sidebar.text_input(
27
+ label="PDB ID",
28
+ value="4RW0",
29
+ )
30
+ structure = get_structure(pdb_id)
31
+ chains = get_chains(structure)
32
+ selected_chains = st.sidebar.multiselect(label="Chain(s)", options=chains, default=chains)
33
+
34
+ hl_resi_list = st.sidebar.multiselect(label="Highlight Residues",options=list(range(1,5000)))
35
+
36
+ label_resi = st.sidebar.checkbox(label="Label Residues", value=True)
37
+
38
 
39
  left, right = st.columns(2)
40
  with left:
 
48
  with st.expander("Configure parameters", expanded=False):
49
  min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
50
  try:
 
51
  ec_class = structure.header["compound"]["1"]["ec"]
52
  except KeyError:
53
  ec_class = None
54
  if ec_class and selected_model.name == ModelType.ZymCTRL:
55
  ec_class = st.text_input("Enzyme classification number fetched from PDB", ec_class)
56
 
57
+ attention_pairs = get_attention_pairs(pdb_id, chain_ids=selected_chains, layer=layer, head=head, threshold=min_attn, model_type=selected_model.name)
58
 
59
  def get_3dview(pdb):
60
  xyzview = py3Dmol.view(query=f"pdb:{pdb}")
61
  xyzview.setStyle({"cartoon": {"color": "spectrum"}})
62
  stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
63
+
64
+
65
+ hidden_chains = [x for x in chains if x not in selected_chains]
66
+ for chain in hidden_chains:
67
+ xyzview.setStyle({"chain": chain},{"cross":{"hidden":"true"}})
68
+
69
  for att_weight, first, second in attention_pairs:
70
  stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight, cylColor='red', dashed=False)
71
+
72
+ if label_resi:
73
+ for hl_resi in hl_resi_list:
74
+ xyzview.addResLabels({"chain": chain,"resi": hl_resi},
75
+ {"backgroundColor": "lightgray","fontColor": "black","backgroundOpacity": 0.5})
76
  return xyzview
77
 
78
 
hexviz/attention.py CHANGED
@@ -1,6 +1,5 @@
1
- from enum import Enum
2
  from io import StringIO
3
- from typing import List, Tuple
4
  from urllib import request
5
 
6
  import streamlit as st
@@ -21,20 +20,27 @@ def get_structure(pdb_code: str) -> Structure:
21
  structure = parser.get_structure(pdb_code, file)
22
  return structure
23
 
24
- def get_sequences(structure: Structure) -> List[str]:
25
  """
26
- Get list of sequences with residues on a single letter format
 
 
 
 
 
 
 
 
 
 
27
 
28
  Residues not in the standard 20 amino acids are replaced with X
29
  """
30
- sequences = []
31
- for seq in structure.get_chains():
32
- residues = [residue.get_resname() for residue in seq.get_residues()]
33
- # TODO ask if using protein_letters_3to1_extended makes sense
34
- residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
35
 
36
- sequences.append("".join(list(residues_single_letter)))
37
- return sequences
38
 
39
  @st.cache
40
  def get_attention(
@@ -100,17 +106,19 @@ def unidirectional_avg_filtered(attention, layer, head, threshold):
100
  return unidirectional_avg_for_head
101
 
102
  @st.cache
103
- def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
104
- # fetch structure
105
  structure = get_structure(pdb_code=pdb_code)
106
- # Get list of sequences
107
- sequences = get_sequences(structure)
 
 
 
108
 
109
  attention_pairs = []
110
- for i, sequence in enumerate(sequences):
 
111
  attention = get_attention(sequence=sequence, model_type=model_type)
112
  attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
113
- chain = list(structure.get_chains())[i]
114
  for attn_value, res_1, res_2 in attention_unidirectional:
115
  try:
116
  coord_1 = chain[res_1]["CA"].coord.tolist()
 
 
1
  from io import StringIO
2
+ from typing import List, Optional
3
  from urllib import request
4
 
5
  import streamlit as st
 
20
  structure = parser.get_structure(pdb_code, file)
21
  return structure
22
 
23
+ def get_chains(structure: Structure) -> List[str]:
24
  """
25
+ Get list of chains in a structure
26
+ """
27
+ chains = []
28
+ for model in structure:
29
+ for chain in model.get_chains():
30
+ chains.append(chain.id)
31
+ return chains
32
+
33
+ def get_sequence(chain) -> str:
34
+ """
35
+ Get sequence from a chain
36
 
37
  Residues not in the standard 20 amino acids are replaced with X
38
  """
39
+ residues = [residue.get_resname() for residue in chain.get_residues()]
40
+ # TODO ask if using protein_letters_3to1_extended makes sense
41
+ residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
 
 
42
 
43
+ return "".join(list(residues_single_letter))
 
44
 
45
  @st.cache
46
  def get_attention(
 
106
  return unidirectional_avg_for_head
107
 
108
  @st.cache
109
+ def get_attention_pairs(pdb_code: str, layer: int, head: int, chain_ids: Optional[str] = None ,threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
 
110
  structure = get_structure(pdb_code=pdb_code)
111
+
112
+ if chain_ids:
113
+ chains = [ch for ch in structure.get_chains() if ch.id in chain_ids]
114
+ else:
115
+ chains = list(structure.get_chains())
116
 
117
  attention_pairs = []
118
+ for chain in chains:
119
+ sequence = get_sequence(chain)
120
  attention = get_attention(sequence=sequence, model_type=model_type)
121
  attention_unidirectional = unidirectional_avg_filtered(attention, layer, head, threshold)
 
122
  for attn_value, res_1, res_2 in attention_unidirectional:
123
  try:
124
  coord_1 = chain[res_1]["CA"].coord.tolist()
tests/test_attention.py CHANGED
@@ -70,7 +70,7 @@ def test_get_unidirection_avg_filtered():
70
  assert result is not None
71
  assert len(result) == 10
72
 
73
- attention= torch.tensor([[[[1, 2, 3],
74
  [2, 5, 6],
75
  [4, 7, 91]]]], dtype=torch.float32)
76
 
 
70
  assert result is not None
71
  assert len(result) == 10
72
 
73
+ attention = torch.tensor([[[[1, 2, 3],
74
  [2, 5, 6],
75
  [4, 7, 91]]]], dtype=torch.float32)
76