aksell committed
Commit 2338bb4
1 Parent(s): 1431daf

Add ProtGPT2 model

Files changed (2)
  1. hexviz/app.py +2 -1
  2. hexviz/attention.py +32 -11
hexviz/app.py CHANGED
@@ -8,13 +8,14 @@ from hexviz.attention import Model, ModelType, get_attention_pairs
 st.title("Attention Visualization on proteins")
 
 """
-Visualize attention weights on protein structures for the protein language models ZymCTRL and TAPE-BERT.
+Visualize attention weights on protein structures for the protein language models ProtGPT2, TAPE-BERT and ZymCTRL.
 Pick a PDB ID, layer and head to visualize attention.
 """
 
 
 # Define list of model types
 models = [
+    Model(name=ModelType.ProtGPT2, layers=36, heads=20),
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     # Model(name=ModelType.PROT_T5, layers=24, heads=32),
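
The registered layer and head counts are what let the app bound its layer and head pickers per model, which is why ProtGPT2 is added with layers=36 and heads=20. The snippet below is only a hedged illustration of that pattern, not code from this commit; the st.selectbox / st.number_input calls and the variable names are assumptions about app code outside the shown hunk.

# Illustrative sketch (not part of this commit): using the models list defined
# above, plus the streamlit import from app.py, to drive model, layer and head
# selection.
selected_model_name = st.selectbox("Select model", [m.name.value for m in models])
selected_model = next(m for m in models if m.name.value == selected_model_name)

# Pickers are clamped to the selected model's architecture.
layer = st.number_input("Layer", min_value=1, max_value=selected_model.layers, value=1)
head = st.number_input("Head", min_value=1, max_value=selected_model.heads, value=1)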
hexviz/attention.py CHANGED
@@ -15,6 +15,7 @@ class ModelType(str, Enum):
     TAPE_BERT = "TAPE-BERT"
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
     ZymCTRL = "ZymCTRL"
+    ProtGPT2 = "ProtGPT2"
 
 
 class Model:
@@ -77,6 +78,13 @@ def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
     model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
     return tokenizer, model
 
+@st.cache
+def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device('cuda')
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
+    return tokenizer, model
+
 @st.cache
 def get_attention(
     sequence: str, model_type: ModelType = ModelType.TAPE_BERT
@@ -84,6 +92,8 @@ def get_attention(
     """
     Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
     """
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
@@ -95,7 +105,6 @@ def get_attention(
         attentions = torch.stack([attention.squeeze(0) for attention in attentions])
 
     elif model_type == ModelType.ZymCTRL:
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         tokenizer, model = get_zymctrl()
         inputs = tokenizer(sequence, return_tensors='pt').input_ids.to(device)
         attention_mask = tokenizer(sequence, return_tensors='pt').attention_mask.to(device)
@@ -110,20 +119,32 @@
         attention_stacked = torch.stack([attention for attention in attention_squeezed])
         attentions = attention_stacked
 
+    elif model_type == ModelType.ProtGPT2:
+        tokenizer, model = get_protgpt2()
+        input_ids = tokenizer.encode(input, return_tensors='pt').to(device)
+        with torch.no_grad():
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
+            attentions = outputs.attentions
+
+        # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
+        attention_squeezed = [torch.squeeze(attention) for attention in attentions]
+        # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
+        attention_stacked = torch.stack([attention for attention in attention_squeezed])
+        attentions = attention_stacked
+
     elif model_type == ModelType.PROT_T5:
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        # Introduce white-space between all amino acids
-        sequence = " ".join(sequence)
-        # tokenize sequences and pad up to the longest sequence in the batch
-        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
+        # Introduce white-space between all amino acids
+        sequence = " ".join(sequence)
+        # tokenize sequences and pad up to the longest sequence in the batch
+        ids = tokenizer.encode_plus(sequence, add_special_tokens=True, padding="longest")
 
-        input_ids = torch.tensor(ids['input_ids']).to(device)
-        attention_mask = torch.tensor(ids['attention_mask']).to(device)
+        input_ids = torch.tensor(ids['input_ids']).to(device)
+        attention_mask = torch.tensor(ids['attention_mask']).to(device)
 
-        with torch.no_grad():
-            attns = model(input_ids=input_ids,attention_mask=attention_mask)[-1]
+        with torch.no_grad():
+            attns = model(input_ids=input_ids,attention_mask=attention_mask)[-1]
 
-        tokenizer, model = get_protT5()
+        tokenizer, model = get_protT5()
     else:
         raise ValueError(f"Model {model_type} not supported")
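
A note on the new ProtGPT2 branch as committed: it encodes the Python builtin `input` rather than the `sequence` argument, then calls the model with `inputs` and `attention_mask`, neither of which is defined in that branch, and `get_protgpt2` pins the model to `torch.device('cuda')` while the rest of `get_attention` falls back to CPU. The sketch below shows one way the branch's logic could be written so it runs end to end; it is a hedged illustration, not repository code, and the helper name `get_protgpt2_attention` is invented for the example. ProtGPT2 also uses a BPE tokenizer, so attention positions here correspond to tokens rather than one-to-one to residues.

# Hedged sketch (not part of the commit): a self-contained version of the
# ProtGPT2 attention extraction with the issues noted above addressed.
from typing import Tuple

import streamlit as st
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel


@st.cache
def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
    # Fall back to CPU when CUDA is unavailable, matching the device handling
    # at the top of get_attention (the commit hard-codes 'cuda' here).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("nferruz/ProtGPT2")
    model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2").to(device)
    return tokenizer, model


def get_protgpt2_attention(sequence: str) -> torch.Tensor:
    """Return attention weights of shape [n_layers, n_heads, n_tokens, n_tokens]."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer, model = get_protgpt2()

    # Tokenize the sequence argument (the committed branch encodes the `input`
    # builtin and then feeds undefined `inputs` / `attention_mask` to the model).
    tokens = tokenizer(sequence, return_tensors="pt")
    input_ids = tokens.input_ids.to(device)
    attention_mask = tokens.attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)

    # outputs.attentions is a tuple of n_layers tensors, each [1, n_heads, n, n];
    # drop the batch dimension and stack into [n_layers, n_heads, n, n] so the
    # result matches what the other branches of get_attention return.
    return torch.stack([attention.squeeze(0) for attention in outputs.attentions])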