Spaces:
Sleeping
Sleeping
Remove Prot-T5 and ProtGPT2
Browse filesThe attention visualization was not implemented since these
models don't tokenize single residues, and the models are
too large to be loaded with the current infrastructure.
Could be added back in the future if needed.
- hexviz/attention.py +1 -29
- hexviz/models.py +4 -27
- tests/test_models.py +2 -14
hexviz/attention.py
CHANGED
@@ -7,8 +7,7 @@ import streamlit as st
|
|
7 |
import torch
|
8 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
9 |
|
10 |
-
from hexviz.models import
|
11 |
-
get_tape_bert, get_zymctrl)
|
12 |
|
13 |
|
14 |
def get_structure(pdb_code: str) -> Structure:
|
@@ -71,20 +70,6 @@ def get_attention(
|
|
71 |
attention_stacked = torch.stack([attention for attention in attention_squeezed])
|
72 |
attentions = attention_stacked
|
73 |
|
74 |
-
elif model_type == ModelType.ProtGPT2:
|
75 |
-
tokenizer, model = get_protgpt2()
|
76 |
-
input_ids = tokenizer.encode(input, return_tensors='pt').to(device)
|
77 |
-
with torch.no_grad():
|
78 |
-
outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
|
79 |
-
attentions = outputs.attentions
|
80 |
-
|
81 |
-
# torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
|
82 |
-
attention_squeezed = [torch.squeeze(attention) for attention in attentions]
|
83 |
-
# ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
|
84 |
-
attention_stacked = torch.stack([attention for attention in attention_squeezed])
|
85 |
-
attentions = attention_stacked
|
86 |
-
# TODO extend attentions to be per token, not per word piece
|
87 |
-
# simplest way to draw attention for multi residue token models for now
|
88 |
elif model_type == ModelType.PROT_BERT:
|
89 |
tokenizer, model = get_prot_bert()
|
90 |
token_idxs = tokenizer.encode(sequence)
|
@@ -95,19 +80,6 @@ def get_attention(
|
|
95 |
attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
|
96 |
attentions = torch.stack([attention.squeeze(0) for attention in attentions])
|
97 |
|
98 |
-
elif model_type == ModelType.PROT_T5:
|
99 |
-
# Introduce white-space between all amino acids
|
100 |
-
sequence = " ".join(sequence)
|
101 |
-
# tokenize sequences and pad up to the longest sequence in the batch
|
102 |
-
ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
|
103 |
-
|
104 |
-
input_ids = torch.tensor(ids['input_ids']).to(device)
|
105 |
-
attention_mask = torch.tensor(ids['attention_mask']).to(device)
|
106 |
-
|
107 |
-
with torch.no_grad():
|
108 |
-
attns = model(input_ids=input_ids,attention_mask=attention_mask)[-1]
|
109 |
-
|
110 |
-
tokenizer, model = get_protT5()
|
111 |
else:
|
112 |
raise ValueError(f"Model {model_type} not supported")
|
113 |
|
|
|
7 |
import torch
|
8 |
from Bio.PDB import PDBParser, Polypeptide, Structure
|
9 |
|
10 |
+
from hexviz.models import ModelType, get_prot_bert, get_tape_bert, get_zymctrl
|
|
|
11 |
|
12 |
|
13 |
def get_structure(pdb_code: str) -> Structure:
|
|
|
70 |
attention_stacked = torch.stack([attention for attention in attention_squeezed])
|
71 |
attentions = attention_stacked
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
elif model_type == ModelType.PROT_BERT:
|
74 |
tokenizer, model = get_prot_bert()
|
75 |
token_idxs = tokenizer.encode(sequence)
|
|
|
80 |
attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
|
81 |
attentions = torch.stack([attention.squeeze(0) for attention in attentions])
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
else:
|
84 |
raise ValueError(f"Model {model_type} not supported")
|
85 |
|
hexviz/models.py
CHANGED
@@ -5,14 +5,12 @@ import streamlit as st
|
|
5 |
import torch
|
6 |
from tape import ProteinBertModel, TAPETokenizer
|
7 |
from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
|
8 |
-
GPT2LMHeadModel
|
9 |
|
10 |
|
11 |
class ModelType(str, Enum):
|
12 |
TAPE_BERT = "TAPE-BERT"
|
13 |
-
PROT_T5 = "prot_t5_xl_half_uniref50-enc"
|
14 |
ZymCTRL = "ZymCTRL"
|
15 |
-
ProtGPT2 = "ProtGPT2"
|
16 |
PROT_BERT = "ProtBert"
|
17 |
|
18 |
|
@@ -22,20 +20,6 @@ class Model:
|
|
22 |
self.layers: int = layers
|
23 |
self.heads: int = heads
|
24 |
|
25 |
-
@st.cache
|
26 |
-
def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
|
27 |
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
28 |
-
tokenizer = T5Tokenizer.from_pretrained(
|
29 |
-
"Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
|
30 |
-
)
|
31 |
-
|
32 |
-
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
|
33 |
-
device
|
34 |
-
)
|
35 |
-
|
36 |
-
model.full() if device == "cpu" else model.half()
|
37 |
-
|
38 |
-
return tokenizer, model
|
39 |
|
40 |
@st.cache
|
41 |
def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
|
@@ -43,12 +27,6 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
|
|
43 |
model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
|
44 |
return tokenizer, model
|
45 |
|
46 |
-
@st.cache
|
47 |
-
def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
|
48 |
-
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
|
49 |
-
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
|
50 |
-
return tokenizer, model
|
51 |
-
|
52 |
@st.cache
|
53 |
def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
|
54 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
@@ -57,8 +35,7 @@ def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
|
|
57 |
return tokenizer, model
|
58 |
|
59 |
@st.cache
|
60 |
-
def
|
61 |
-
|
62 |
-
|
63 |
-
model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
|
64 |
return tokenizer, model
|
|
|
5 |
import torch
|
6 |
from tape import ProteinBertModel, TAPETokenizer
|
7 |
from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
|
8 |
+
GPT2LMHeadModel)
|
9 |
|
10 |
|
11 |
class ModelType(str, Enum):
|
12 |
TAPE_BERT = "TAPE-BERT"
|
|
|
13 |
ZymCTRL = "ZymCTRL"
|
|
|
14 |
PROT_BERT = "ProtBert"
|
15 |
|
16 |
|
|
|
20 |
self.layers: int = layers
|
21 |
self.heads: int = heads
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
@st.cache
|
25 |
def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
|
|
|
27 |
model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
|
28 |
return tokenizer, model
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
@st.cache
|
31 |
def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
|
32 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
35 |
return tokenizer, model
|
36 |
|
37 |
@st.cache
|
38 |
+
def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
|
39 |
+
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
|
40 |
+
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
|
|
|
41 |
return tokenizer, model
|
tests/test_models.py
CHANGED
@@ -1,21 +1,9 @@
|
|
1 |
|
2 |
-
from transformers import
|
3 |
-
T5Tokenizer)
|
4 |
|
5 |
-
from hexviz.models import
|
6 |
|
7 |
|
8 |
-
def test_get_protT5():
|
9 |
-
result = get_protT5()
|
10 |
-
|
11 |
-
assert result is not None
|
12 |
-
assert isinstance(result, tuple)
|
13 |
-
|
14 |
-
tokenizer, model = result
|
15 |
-
|
16 |
-
assert isinstance(tokenizer, T5Tokenizer)
|
17 |
-
assert isinstance(model, T5EncoderModel)
|
18 |
-
|
19 |
def test_get_zymctrl():
|
20 |
result = get_zymctrl()
|
21 |
|
|
|
1 |
|
2 |
+
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
|
|
|
3 |
|
4 |
+
from hexviz.models import get_zymctrl
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def test_get_zymctrl():
|
8 |
result = get_zymctrl()
|
9 |
|