Visualize attention pairs on structure
Files changed:
- protention/attention.py (+42 -14)
- protention/streamlit/Attention_On_Structure.py (+4 -2)
- tests/test_attention.py (+23 -2)
protention/attention.py
CHANGED
@@ -66,21 +66,11 @@ def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
-@st.
+@st.cache_data
 def get_attention(
-
+    sequence: list[str], model_type: ModelType = ModelType.TAPE_BERT
 ):
-    """
-    Get attention from T5
-    """
-    # fetch structure
-    structure = get_structure(pdb_code)
-    # Get list of sequences
-    sequences = get_sequences(structure)
-    # TODO handle multiple sequences
-    sequence = sequences[0]
-
-    match model.name:
+    match model_type:
         case ModelType.TAPE_BERT:
             tokenizer, model = get_tape_bert()
             token_idxs = tokenizer.encode(sequence).tolist()
@@ -91,9 +81,47 @@ def get_attention(
             attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
             attns = torch.stack([attn.squeeze(0) for attn in attns])
         case ModelType.PROT_T5:
+            attns = None
            # Space separate sequences
            sequences = [" ".join(sequence) for sequence in sequences]
            tokenizer, model = get_protT5()
-
+        case _:
+            raise ValueError(f"Model {model_type} not supported")
     return attns
 
+def unidirectional_sum_filtered(attention, layer, head, threshold):
+    num_layers, num_heads, seq_len, _ = attention.shape
+    attention_head = attention[layer, head]
+    unidirectional_sum_for_head = []
+    for i in range(seq_len):
+        for j in range(i, seq_len):
+            # Attention matrices for BERT models are asymmetric.
+            # Bidirectional attention is reduced to one value by adding the
+            # attention values
+            # TODO think... does this operation make sense?
+            sum = attention_head[i, j].item() + attention_head[j, i].item()
+            if sum >= threshold:
+                unidirectional_sum_for_head.append((sum, i, j))
+    return unidirectional_sum_for_head
+
+@st.cache_data
+def get_attention_pairs(pdb_code: str, layer: int, head: int, threshold: int = 0.2, model_type: ModelType = ModelType.TAPE_BERT):
+    # fetch structure
+    structure = get_structure(pdb_code=pdb_code)
+    # Get list of sequences
+    sequences = get_sequences(structure)
+
+    attention_pairs = []
+    for i, sequence in enumerate(sequences):
+        attention = get_attention(sequence=sequence, model_type=model_type)
+        attention_unidirectional = unidirectional_sum_filtered(attention, layer, head, threshold)
+        chain = list(structure.get_chains())[i]
+        for attn_value, res_1, res_2 in attention_unidirectional:
+            try:
+                coord_1 = chain[res_1]["CA"].coord.tolist()
+                coord_2 = chain[res_2]["CA"].coord.tolist()
+            except KeyError:
+                continue
+            attention_pairs.append((attn_value, coord_1, coord_2))
+
+    return attention_pairs
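For context, a rough usage sketch of the new get_attention_pairs helper outside Streamlit (a sketch only: the layer/head values are illustrative, "1AKE" is just an example PDB code, and it assumes the protention package and the TAPE model weights are available locally):

    from protention.attention import ModelType, get_attention_pairs

    # Each element is (summed attention between residues i and j, CA coord of i, CA coord of j),
    # keeping only pairs whose summed attention clears the default 0.2 threshold.
    pairs = get_attention_pairs("1AKE", layer=3, head=5, model_type=ModelType.TAPE_BERT)
    for attn_sum, coord_1, coord_2 in pairs[:5]:
        print(f"{attn_sum:.3f}: {coord_1} -> {coord_2}")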
protention/streamlit/Attention_On_Structure.py
CHANGED
@@ -3,7 +3,7 @@ import stmol
 import streamlit as st
 from stmol import showmol
 
-from protention.attention import Model, ModelType,
+from protention.attention import Model, ModelType, get_attention_pairs
 
 st.sidebar.title("pLM Attention Visualization")
 
@@ -27,12 +27,14 @@ with right:
 
 min_attn = st.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.15)
 
-
+attention_pairs = get_attention_pairs(pdb_id, layer, head, min_attn, model_type=selected_model.name)
 
 def get_3dview(pdb):
     xyzview = py3Dmol.view(query=f"pdb:{pdb}")
     xyzview.setStyle({"cartoon": {"color": "spectrum"}})
     stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
+    for att_weight, first, second in attention_pairs:
+        stmol.add_cylinder(xyzview, start=first, end=second, cylradius=att_weight*3, cylColor='red', dashed=False)
     return xyzview
 
 
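The cylinders added to get_3dview scale their radius with the summed attention weight (cylradius=att_weight*3), so strongly attending residue pairs show up as thicker red links. A minimal standalone sketch of that overlay, assuming py3Dmol and stmol are installed and using made-up coordinates and weight:

    import py3Dmol
    import stmol
    from stmol import showmol

    xyzview = py3Dmol.view(query="pdb:1AKE")
    xyzview.setStyle({"cartoon": {"color": "spectrum"}})

    att_weight = 0.25              # summed attention for one residue pair (example value)
    first = [26.3, 24.1, 8.5]      # CA coordinate of residue i (example values)
    second = [30.7, 19.8, 12.2]    # CA coordinate of residue j (example values)
    stmol.add_cylinder(xyzview, start=first, end=second,
                       cylradius=att_weight * 3, cylColor='red', dashed=False)

    # In the Streamlit app the view is rendered with showmol:
    showmol(xyzview, height=500, width=500)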
tests/test_attention.py
CHANGED
@@ -3,7 +3,8 @@ from Bio.PDB.Structure import Structure
 from transformers import T5EncoderModel, T5Tokenizer
 
 from protention.attention import (ModelType, get_attention, get_protT5,
-                                  get_sequences, get_structure
+                                  get_sequences, get_structure,
+                                  unidirectional_sum_filtered)
 
 
 def test_get_structure():
@@ -38,7 +39,27 @@ def test_get_protT5():
 
 def test_get_attention_tape():
 
-    result = get_attention("1AKE", model=ModelType.
+    result = get_attention("1AKE", model=ModelType.TAPE_BERT)
 
     assert result is not None
     assert result.shape == torch.Size([12,12,456,456])
+
+def test_get_unidirection_sum_filtered():
+    # 1 head, 1 layer, 4 residues long attention tensor
+    attention= torch.tensor([[[[1, 2, 3, 4],
+                               [2, 5, 6, 7],
+                               [3, 6, 8, 9],
+                               [4, 7, 9, 11]]]], dtype=torch.float32)
+
+    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+
+    assert result is not None
+    assert len(result) == 10
+
+    attention= torch.tensor([[[[1, 2, 3],
+                               [2, 5, 6],
+                               [4, 7, 91]]]], dtype=torch.float32)
+
+    result = unidirectional_sum_filtered(attention, 0, 0, 0)
+
+    assert len(result) == 6
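The expected lengths in test_get_unidirection_sum_filtered follow from the loop bounds in unidirectional_sum_filtered: with a zero threshold and strictly positive attention values, every (i, j) pair with j >= i is kept, i.e. the triangular number n*(n+1)/2. A quick sanity check (expected_pairs is just an illustrative helper, not part of the codebase):

    def expected_pairs(n: int) -> int:
        # Number of residue pairs (i, j) with j >= i, diagonal included.
        return n * (n + 1) // 2

    assert expected_pairs(4) == 10   # 4-residue test tensor
    assert expected_pairs(3) == 6    # 3-residue test tensor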