Add PROT BERT
- hexviz/attention.py +14 -3
- hexviz/models.py +9 -2
- tests/test_attention.py +7 -0
hexviz/attention.py
CHANGED
@@ -6,11 +6,11 @@ from urllib import request
 import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
-
-
+
+from hexviz.models import (ModelType, get_prot_bert, get_protgpt2, get_protT5,
+                           get_tape_bert, get_zymctrl)
 
 
-@st.cache
 def get_structure(pdb_code: str) -> Structure:
     """
     Get structure from PDB
@@ -83,6 +83,17 @@ def get_attention(
         # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
         attention_stacked = torch.stack([attention for attention in attention_squeezed])
         attentions = attention_stacked
+        # TODO extend attentions to be per token, not per word piece
+        # simplest way to draw attention for multi residue token models for now
+    elif model_type == ModelType.PROT_BERT:
+        tokenizer, model = get_prot_bert()
+        token_idxs = tokenizer.encode(sequence)
+        inputs = torch.tensor(token_idxs).unsqueeze(0)
+        with torch.no_grad():
+            attentions = model(inputs)[-1]
+        # Remove attention from <CLS> (first) and <SEP> (last) token
+        attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+        attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.PROT_T5:
         # Introduce white-space between all amino acids
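For orientation, a minimal standalone sketch (not part of the commit) of what the new PROT_BERT branch computes, using Hugging Face transformers directly. Unlike the branch above, it requests output_attentions=True when loading the model and feeds the space-separated sequence the Rostlab/prot_bert tokenizer is documented to expect; the model name and tensor shapes are real, everything else is illustrative.

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertModel.from_pretrained("Rostlab/prot_bert", output_attentions=True)

sequence = "G G G"  # ProtBert expects residues separated by whitespace
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple of n_layers tensors, each [1, n_heads, n_tokens, n_tokens].
# Drop the [CLS] (first) and [SEP] (last) positions, then stack to [n_layers, n_heads, n_res, n_res].
attentions = torch.stack([a[0, :, 1:-1, 1:-1] for a in outputs.attentions])
print(attentions.shape)  # torch.Size([30, 16, 3, 3]) -- ProtBert has 30 layers and 16 heads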
hexviz/models.py
CHANGED
@@ -4,8 +4,8 @@ from typing import Tuple
 import streamlit as st
 import torch
 from tape import ProteinBertModel, TAPETokenizer
-from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
-                          T5Tokenizer)
+from transformers import (AutoTokenizer, BertForMaskedLM, BertTokenizer,
+                          GPT2LMHeadModel, T5EncoderModel, T5Tokenizer)
 
 
 class ModelType(str, Enum):
@@ -13,6 +13,7 @@ class ModelType(str, Enum):
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
     ZymCTRL = "ZymCTRL"
     ProtGPT2 = "ProtGPT2"
+    PROT_BERT = "ProtBert"
 
 
 class Model:
@@ -42,6 +43,12 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
     model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
     return tokenizer, model
 
+@st.cache
+def get_prot_bert() -> Tuple[BertTokenizer, BertForMaskedLM]:
+    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+    model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
+    return tokenizer, model
+
 @st.cache
 def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
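The new loader follows the same pattern as the existing ones: it is wrapped in @st.cache so the ProtBert weights are downloaded and loaded only once per Streamlit session. Below is a hedged sketch of how get_prot_bert might be consumed from a Streamlit page; the selectbox wiring is illustrative, and only ModelType, get_prot_bert, and the BertConfig fields come from the code above or the published ProtBert configuration.

import streamlit as st

from hexviz.models import ModelType, get_prot_bert

model_name = st.selectbox("Model", [m.value for m in ModelType])
if model_name == ModelType.PROT_BERT.value:
    # Cached by @st.cache: the checkpoint is only loaded on the first call
    tokenizer, model = get_prot_bert()
    st.write(f"Loaded ProtBert: {model.config.num_hidden_layers} layers, "
             f"{model.config.num_attention_heads} attention heads")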
tests/test_attention.py
CHANGED
@@ -51,6 +51,13 @@ def test_get_attention_tape():
     assert result is not None
     assert result.shape == torch.Size([12,12,13,13])
 
+def test_get_attention_prot_bert():
+
+    result = get_attention("GGG", model_type=ModelType.PROT_BERT)
+
+    assert result is not None
+    assert result.shape == torch.Size([30, 16, 3, 3])
+
 def test_get_unidirection_sum_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
     attention= torch.tensor([[[[1, 2, 3, 4],
|