Add ProtGPT2 model
- hexviz/app.py +2 -1
- hexviz/attention.py +32 -11
hexviz/app.py
CHANGED
@@ -8,13 +8,14 @@ from hexviz.attention import Model, ModelType, get_attention_pairs
 st.title("Attention Visualization on proteins")
 
 """
-Visualize attention weights on protein structures for the protein language models
+Visualize attention weights on protein structures for the protein language models ProtGPT2, TAPE-BERT and ZymCTRL.
 Pick a PDB ID, layer and head to visualize attention.
 """
 
 
 # Define list of model types
 models = [
+    Model(name=ModelType.ProtGPT2, layers=36, heads=20),
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     # Model(name=ModelType.PROT_T5, layers=24, heads=32),
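For context, each Model entry records how many layers and heads the checkpoint has, so the app can bound its layer/head pickers. A minimal sketch of that wiring follows; the widget labels and selection logic are illustrative assumptions, not code from this commit:

import streamlit as st

from hexviz.attention import Model, ModelType

models = [
    Model(name=ModelType.ProtGPT2, layers=36, heads=20),
    Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
    Model(name=ModelType.ZymCTRL, layers=36, heads=16),
]

# Pick a model, then bound the layer/head pickers by its architecture
selected = st.selectbox("Model", models, format_func=lambda m: m.name.value)
layer = st.number_input("Layer", min_value=1, max_value=selected.layers)
head = st.number_input("Head", min_value=1, max_value=selected.heads)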
hexviz/attention.py
CHANGED
@@ -15,6 +15,7 @@ class ModelType(str, Enum):
     TAPE_BERT = "TAPE-BERT"
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
     ZymCTRL = "ZymCTRL"
+    ProtGPT2 = "ProtGPT2"
 
 
 class Model:
@@ -77,6 +78,13 @@ def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
     model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
     return tokenizer, model
 
+@st.cache
+def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
+    return tokenizer, model
+
 @st.cache
 def get_attention(
     sequence: str, model_type: ModelType = ModelType.TAPE_BERT
@@ -84,6 +92,8 @@ def get_attention(
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
     """
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
@@ -95,7 +105,6 @@ def get_attention(
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.ZymCTRL:
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         tokenizer, model = get_zymctrl()
         inputs = tokenizer(sequence, return_tensors='pt').input_ids.to(device)
         attention_mask = tokenizer(sequence, return_tensors='pt').attention_mask.to(device)
@@ -110,20 +119,32 @@ def get_attention(
         attention_stacked = torch.stack([attention for attention in attention_squeezed])
         attentions = attention_stacked
 
+    elif model_type == ModelType.ProtGPT2:
+        tokenizer, model = get_protgpt2()
+        input_ids = tokenizer.encode(sequence, return_tensors='pt').to(device)
+        with torch.no_grad():
+            outputs = model(input_ids, output_attentions=True)
+            attentions = outputs.attentions
+
+        # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
+        attention_squeezed = [torch.squeeze(attention) for attention in attentions]
+        # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
+        attention_stacked = torch.stack([attention for attention in attention_squeezed])
+        attentions = attention_stacked
+
     elif model_type == ModelType.PROT_T5:
-
-
-
-
-        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
+        tokenizer, model = get_protT5()
+        # Introduce white-space between all amino acids
+        sequence = " ".join(sequence)
+        # tokenize sequences and pad up to the longest sequence in the batch
+        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
 
-
-
+        input_ids = torch.tensor(ids['input_ids']).to(device)
+        attention_mask = torch.tensor(ids['attention_mask']).to(device)
 
-
-
+        with torch.no_grad():
+            attns = model(input_ids=input_ids, attention_mask=attention_mask)[-1]
 
-
     else:
         raise ValueError(f"Model {model_type} not supported")
 
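As a quick smoke test for the new branch, something like the following should work. The sequence is an arbitrary example; note that ProtGPT2 uses a BPE tokenizer, so there is no guarantee of one token per residue and the trailing dimensions may not equal the sequence length:

from hexviz.attention import ModelType, get_attention

# Arbitrary example sequence, for illustration only
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

attentions = get_attention(sequence, model_type=ModelType.ProtGPT2)
# Per the docstring: [n_layers, n_heads, n_res, n_res];
# ProtGPT2 has 36 layers and 20 heads
print(attentions.shape)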