aksell committed
Commit: 58c7b8d
Parent: ebbe380

Visualize attention pairs on structure

protention/attention.py CHANGED
@@ -66,21 +66,11 @@ def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
-@st.cache
+@st.cache_data
 def get_attention(
-    pdb_code: str, model: ModelType = ModelType.TAPE_BERT
+    sequence: list[str], model_type: ModelType = ModelType.TAPE_BERT
 ):
-    """
-    Get attention from T5
-    """
-    # fetch structure
-    structure = get_structure(pdb_code)
-    # Get list of sequences
-    sequences = get_sequences(structure)
-    # TODO handle multiple sequences
-    sequence = sequences[0]
-
-    match model.name:
+    match model_type:
         case ModelType.TAPE_BERT:
             tokenizer, model = get_tape_bert()
             token_idxs = tokenizer.encode(sequence).tolist()
@@ -91,9 +81,47 @@ def get_attention(
             attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
             attns = torch.stack([attn.squeeze(0) for attn in attns])
         case ModelType.PROT_T5:
+            attns = None
             # Space separate sequences
             sequences = [" ".join(sequence) for sequence in sequences]
             tokenizer, model = get_protT5()
-
+        case _:
+            raise ValueError(f"Model {model_type} not supported")
     return attns
 
+def unidirectional_sum_filtered(attention, layer, head, threshold):
+    num_layers, num_heads, seq_len, _ = attention.shape
+    attention_head = attention[layer, head]
+    unidirectional_sum_for_head = []
+    for i in range(seq_len):
+        for j in range(i, seq_len):
+            # Attention matrices for BERT models are asymmetric.
+            # Bidirectional attention is reduced to one value by adding the
+            # attention values.
+            # TODO think... does this operation make sense?
+            sum = attention_head[i, j].item() + attention_head[j, i].item()
+            if sum >= threshold:
+                unidirectional_sum_for_head.append((sum, i, j))
+    return unidirectional_sum_for_head
+
+@st.cache_data
+def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: float = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
+    # Fetch structure
+    structure = get_structure(pdb_code=pdb_code)
+    # Get list of sequences, one per chain
+    sequences = get_sequences(structure)
+
+    attention_pairs = []
+    for i, sequence in enumerate(sequences):
+        attention = get_attention(sequence=sequence, model_type=model_type)
+        attention_unidirectional = unidirectional_sum_filtered(attention, layer, head, threshold)
+        chain = list(structure.get_chains())[i]
+        for attn_value, res_1, res_2 in attention_unidirectional:
+            try:
+                coord_1 = chain[res_1]["CA"].coord.tolist()
+                coord_2 = chain[res_2]["CA"].coord.tolist()
+            except KeyError:
+                continue
+            attention_pairs.append((attn_value, coord_1, coord_2))
+
+    return attention_pairs
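The new get_attention_pairs returns a list of (summed attention weight, CA coordinate of residue i, CA coordinate of residue j) tuples, and the filtering step works on the [layer, head] slice of the stacked attention tensor: a residue pair (i, j) with i <= j is kept when the attention in both directions sums to at least the threshold. A minimal, self-contained sketch of that filtering idea follows; the sum_filter name, the random tensor, and the 0.9 threshold are illustrative only and not part of the commit.

import torch

def sum_filter(attention_head: torch.Tensor, threshold: float) -> list[tuple[float, int, int]]:
    # attention_head is the (seq_len, seq_len) slice taken at one layer and one head.
    seq_len = attention_head.shape[0]
    pairs = []
    for i in range(seq_len):
        for j in range(i, seq_len):
            # Fold the two attention directions (i -> j and j -> i) into one score.
            score = attention_head[i, j].item() + attention_head[j, i].item()
            if score >= threshold:
                pairs.append((score, i, j))
    return pairs

# Toy tensor shaped like the stacked TAPE-BERT output: (layers, heads, seq_len, seq_len).
attns = torch.rand(12, 12, 8, 8)
print(sum_filter(attns[0, 0], threshold=0.9))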
protention/streamlit/Attention_On_Structure.py CHANGED
@@ -3,7 +3,7 @@ import stmol
 import streamlit as st
 from stmol import showmol
 
-from protention.attention import Model, ModelType, get_attention
+from protention.attention import Model, ModelType, get_attention_pairs
 
 st.sidebar.title("pLM Attention Visualization")
 
@@ -27,12 +27,14 @@ with right:
 
 min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.15)
 
-attention = get_attention(pdb_id, model=selected_model.name)
+attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
 
 def get_3dview(pdb):
     xyzview = py3Dmol.view(query=f"pdb:{pdb}")
     xyzview.setStyle({"cartoon": {"color": "spectrum"}})
     stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
+    for att_weight, first, second in attention_pairs:
+        stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight*3, cylColor='red', dashed=False)
     return xyzview
 
 
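The cylinder radius scales with the summed attention weight, so stronger pairs show up as thicker red links in the viewer. A rough sketch of how the overlay could be exercised outside the full app is given below; the hard-coded pairs and the height/width values are illustrative assumptions, while the add_cylinder arguments and the py3Dmol setup mirror the diff above.

import py3Dmol
import stmol
from stmol import showmol

# Hypothetical pairs: (summed attention weight, CA coordinates of residue i, CA coordinates of residue j).
attention_pairs = [
    (0.31, [12.4, 8.1, -3.0], [15.9, 2.7, -1.2]),
    (0.22, [4.0, -6.5, 10.1], [7.3, -2.2, 8.8]),
]

xyzview = py3Dmol.view(query="pdb:1AKE")
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
for att_weight, first, second in attention_pairs:
    # Radius scales with the summed attention weight, as in the Streamlit page.
    stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight * 3, cylColor="red", dashed=False)
showmol(xyzview, height=500, width=700)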
tests/test_attention.py CHANGED
@@ -3,7 +3,8 @@ from Bio.PDB.Structure import Structure
 from transformers import T5EncoderModel, T5Tokenizer
 
 from protention.attention import (ModelType, get_attention, get_protT5,
-                                  get_sequences, get_structure)
+                                  get_sequences, get_structure,
+                                  unidirectional_sum_filtered)
 
 
 def test_get_structure():
@@ -38,7 +39,27 @@ def test_get_protT5():
 
 def test_get_attention_tape():
 
-    result = get_attention("1AKE", model=ModelType.tape_bert)
+    result = get_attention("1AKE", model=ModelType.TAPE_BERT)
 
     assert result is not None
     assert result.shape == torch.Size([12, 12, 456, 456])
+
+def test_get_unidirection_sum_filtered():
+    # 1 layer, 1 head, 4-residue-long attention tensor
+    attention = torch.tensor([[[[1, 2, 3, 4],
+                                [2, 5, 6, 7],
+                                [3, 6, 8, 9],
+                                [4, 7, 9, 11]]]], dtype=torch.float32)
+
+    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+
+    assert result is not None
+    assert len(result) == 10
+
+    attention = torch.tensor([[[[1, 2, 3],
+                                [2, 5, 6],
+                                [4, 7, 91]]]], dtype=torch.float32)
+
+    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+
+    assert len(result) == 6
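For context on the expected lengths: with a threshold of 0 every pair (i, j) with i <= j passes the filter, so a sequence of length n yields n(n+1)/2 pairs, i.e. 10 for the 4-residue tensor and 6 for the 3-residue one. A standalone sanity check of that count (the expected_pair_count helper is illustrative, not part of the test file):

def expected_pair_count(seq_len: int) -> int:
    # Number of (i, j) pairs with i <= j for a sequence of length seq_len.
    return seq_len * (seq_len + 1) // 2

assert expected_pair_count(4) == 10
assert expected_pair_count(3) == 6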