Add ZymCTRL model

Needs some tweaking: massive attention to the same residue makes weird disks in the attention bars, and it is noticeably slow on a CPU-only instance.
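The disks come from residues attending strongly to themselves, i.e. large values on the diagonal of each attention map. A minimal sketch of one possible tweak (not part of this commit; the function name is hypothetical) that zeroes self-attention in the [n_layers, n_heads, n_res, n_res] tensor returned by get_attention before the bars are drawn:

```python
import torch

def mask_self_attention(attentions: torch.Tensor) -> torch.Tensor:
    """Zero out each residue's attention to itself (the matrix diagonal).

    Expects the [n_layers, n_heads, n_res, n_res] tensor that get_attention
    returns; the mask broadcasts over the layer and head dimensions.
    """
    n_res = attentions.shape[-1]
    keep = ~torch.eye(n_res, dtype=torch.bool)  # False only on the diagonal
    return attentions * keep
```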
- hexviz/app.py +1 -0
- hexviz/attention.py +49 -6
- tests/test_attention.py +35 -4
hexviz/app.py CHANGED

@@ -9,6 +9,7 @@ st.title("pLM Attention Visualization")
 
 # Define list of model types
 models = [
+    Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     # Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
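The layer and head counts for ZymCTRL are hardcoded here. As a quick sanity check (not part of the commit, and assuming the checkpoint's config exposes GPT-2-style n_layer/n_head fields), the same numbers can be read from the published config:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("nferruz/ZymCTRL")
assert config.n_layer == 36  # -> layers=36 in the Model entry
assert config.n_head == 16   # -> heads=16 in the Model entry
```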
hexviz/attention.py CHANGED

@@ -7,12 +7,14 @@ import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 from tape import ProteinBertModel, TAPETokenizer
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
+                          T5Tokenizer)
 
 
 class ModelType(str, Enum):
     TAPE_BERT = "bert-base"
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
+    ZymCTRL = "zymctrl"
 
 
 class Model:

@@ -45,7 +47,7 @@ def get_sequences(structure: Structure) -> List[str]:
         # TODO ask if using protein_letters_3to1_extended makes sense
         residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
 
-        sequences.append(list(residues_single_letter))
+        sequences.append("".join(list(residues_single_letter)))
     return sequences
 
 @st.cache

@@ -69,18 +71,59 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
     return tokenizer, model
 
 @st.cache
+def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ZymCTRL')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
+    return tokenizer, model
+
 def get_attention(
     sequence: str, model_type: ModelType = ModelType.TAPE_BERT
 ):
+    """
+    Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
+    """
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
         inputs = torch.tensor(token_idxs).unsqueeze(0)
         with torch.no_grad():
+            attentions = model(inputs)[-1]
             # Remove attention from <CLS> (first) and <SEP> (last) token
+            attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+            attentions = torch.stack([attention.squeeze(0) for attention in attentions])
+
+    elif model_type == ModelType.ZymCTRL:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        tokenizer, model = get_zymctrl()
+        inputs = tokenizer(sequence, return_tensors='pt').input_ids.to(device)
+        attention_mask = tokenizer(sequence, return_tensors='pt').attention_mask.to(device)
+
+        with torch.no_grad():
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
+            attentions = outputs.attentions
+            if attentions[0].shape[-1] == attentions[0].shape[-2] == 1:
+                reshaped = [attention.view(attention.shape[1], attention.shape[0]) for attention in attentions]
+                n_residues = reshaped[0].shape[-1]
+                n_heads = 16
+                i,j = torch.triu_indices(n_residues, n_residues)
+
+                attentions_symmetric = []
+                # Make symmetric attention matrix
+                for attention in reshaped:
+                    x = torch.zeros(n_heads, n_residues, n_residues)
+                    x[:,i,j] = attention
+                    x[:,j,i] = attention
+                    attentions_symmetric.append(x)
+                attentions = torch.stack([attention for attention in attentions_symmetric])
+            else:
+                # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
+                attention_squeezed = [torch.squeeze(attention) for attention in attentions]
+
+                # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
+                attention_stacked = torch.stack([attention for attention in attention_squeezed])
+                attentions = attention_stacked
+
     elif model_type == ModelType.PROT_T5:
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         # Introduce white-space between all amino acids

@@ -98,7 +141,7 @@ def get_attention(
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    return
+    return attentions
 
 def unidirectional_sum_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
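When the model hands back attentions whose trailing dimensions are 1, the new ZymCTRL branch rebuilds full residue-by-residue maps by scattering one value per residue pair into both halves of a symmetric matrix. A self-contained toy illustration of that torch.triu_indices trick (shapes and values below are made up, not taken from the model):

```python
import torch

# Toy version of the symmetrization used in the ZymCTRL branch: one value per
# (head, upper-triangular residue pair) is written to both (i, j) and (j, i).
n_heads, n_res = 2, 4
i, j = torch.triu_indices(n_res, n_res)   # upper-triangle indices, diagonal included
values = torch.rand(n_heads, i.shape[0])  # stand-in for one layer's attention values

x = torch.zeros(n_heads, n_res, n_res)
x[:, i, j] = values
x[:, j, i] = values

assert torch.equal(x, x.transpose(-1, -2))  # symmetric by construction
```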
tests/test_attention.py CHANGED

@@ -1,9 +1,10 @@
 import torch
 from Bio.PDB.Structure import Structure
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, T5EncoderModel,
+                          T5Tokenizer)
 
 from hexviz.attention import (ModelType, get_attention, get_protT5,
-                              get_sequences, get_structure,
+                              get_sequences, get_structure, get_zymctrl,
                               unidirectional_sum_filtered)
 
 

@@ -37,12 +38,42 @@ def test_get_protT5():
     assert isinstance(tokenizer, T5Tokenizer)
     assert isinstance(model, T5EncoderModel)
 
+def test_get_zymctrl():
+    result = get_zymctrl()
+
+    assert result is not None
+    assert isinstance(result, tuple)
+
+    tokenizer, model = result
+
+    assert isinstance(tokenizer, GPT2TokenizerFast)
+    assert isinstance(model, GPT2LMHeadModel)
+
+def test_get_attention_zymctrl():
+
+    result = get_attention("GGG", model_type=ModelType.ZymCTRL)
+
+    assert result is not None
+    assert result.shape == torch.Size([36,16,3,3])
+
+def test_get_attention_zymctrl_long_chain():
+    structure = get_structure(pdb_code="6A5J") # 13 residues long
+
+    sequences = get_sequences(structure)
+
+    result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)
+
+    assert result is not None
+    assert result.shape == torch.Size([36,16,13,13])
+
 def test_get_attention_tape():
+    structure = get_structure(pdb_code="6A5J") # 13 residues long
+    sequences = get_sequences(structure)
 
+    result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)
 
     assert result is not None
+    assert result.shape == torch.Size([12,12,13,13])
 
 def test_get_unidirection_sum_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor