aksell committed on
Commit cfba77f
1 parent: 5b6d16d

Add ZymCTRL model


Needs some tweaking. Massive attention to the same residue makes weird disks
in the attention bars. It is also noticeably slow on a CPU-only instance.
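One possible tweak for the disk artefact, offered here only as a sketch and not as part of this commit: zero out each residue's attention to itself before drawing the bars. The helper below is hypothetical and assumes the [n_layers, n_heads, n_res, n_res] tensor that get_attention returns.

import torch

def remove_self_attention(attentions: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not in this commit: zero the diagonal
    # (each residue attending to itself) of a [n_layers, n_heads, n_res, n_res] tensor.
    n_res = attentions.shape[-1]
    mask = ~torch.eye(n_res, dtype=torch.bool)  # False on the diagonal, True elsewhere
    return attentions * mask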

Files changed (3)
  1. hexviz/app.py +1 -0
  2. hexviz/attention.py +49 -6
  3. tests/test_attention.py +35 -4
hexviz/app.py CHANGED
@@ -9,6 +9,7 @@ st.title("pLM Attention Visualization")
 
 # Define list of model types
 models = [
+    Model(name=ModelType.ZymCTRL, layers=36, heads=16),
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
     # Model(name=ModelType.PROT_T5, layers=24, heads=32),
 ]
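The layers=36, heads=16 figures match the GPT2-style ZymCTRL checkpoint. As a rough sketch of how the app presumably uses this metadata (the Streamlit widgets below are an assumption, not code from this commit), the selected model's layer and head counts would bound the layer/head pickers:

import streamlit as st

# Assumed usage of the models list defined above; widget labels are illustrative.
selected_model = st.selectbox("Select model", models, format_func=lambda m: m.name.value)
layer = st.number_input("Layer", min_value=1, max_value=selected_model.layers, value=1)
head = st.number_input("Head", min_value=1, max_value=selected_model.heads, value=1)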
hexviz/attention.py CHANGED
@@ -7,12 +7,14 @@ import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
 from tape import ProteinBertModel, TAPETokenizer
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
+                          T5Tokenizer)
 
 
 class ModelType(str, Enum):
     TAPE_BERT = "bert-base"
     PROT_T5 = "prot_t5_xl_half_uniref50-enc"
+    ZymCTRL = "zymctrl"
 
 
 class Model:
@@ -45,7 +47,7 @@ def get_sequences(structure: Structure) -> List[str]:
         # TODO ask if using protein_letters_3to1_extended makes sense
         residues_single_letter = map(lambda x: Polypeptide.protein_letters_3to1.get(x, "X"), residues)
 
-        sequences.append(list(residues_single_letter))
+        sequences.append("".join(list(residues_single_letter)))
     return sequences
 
 @st.cache
@@ -69,18 +71,59 @@ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
     return tokenizer, model
 
 @st.cache
+def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ZymCTRL')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
+    return tokenizer, model
+
 def get_attention(
     sequence: str, model_type: ModelType = ModelType.TAPE_BERT
 ):
+    """
+    Returns a tensor of shape [n_layers, n_heads, n_res, n_res] with attention weights
+    """
     if model_type == ModelType.TAPE_BERT:
         tokenizer, model = get_tape_bert()
         token_idxs = tokenizer.encode(sequence).tolist()
         inputs = torch.tensor(token_idxs).unsqueeze(0)
         with torch.no_grad():
-            attns = model(inputs)[-1]
+            attentions = model(inputs)[-1]
         # Remove attention from <CLS> (first) and <SEP> (last) token
-        attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
-        attns = torch.stack([attn.squeeze(0) for attn in attns])
+        attentions = [attention[:, :, 1:-1, 1:-1] for attention in attentions]
+        attentions = torch.stack([attention.squeeze(0) for attention in attentions])
+
+    elif model_type == ModelType.ZymCTRL:
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        tokenizer, model = get_zymctrl()
+        inputs = tokenizer(sequence, return_tensors='pt').input_ids.to(device)
+        attention_mask = tokenizer(sequence, return_tensors='pt').attention_mask.to(device)
+
+        with torch.no_grad():
+            outputs = model(inputs, attention_mask=attention_mask, output_attentions=True)
+            attentions = outputs.attentions
+        if attentions[0].shape[-1] == attentions[0].shape[-2] == 1:
+            reshaped = [attention.view(attention.shape[1], attention.shape[0]) for attention in attentions]
+            n_residues = reshaped[0].shape[-1]
+            n_heads = 16
+            i,j = torch.triu_indices(n_residues, n_residues)
+
+            attentions_symmetric = []
+            # Make symmetric attention matrix
+            for attention in reshaped:
+                x = torch.zeros(n_heads, n_residues, n_residues)
+                x[:,i,j] = attention
+                x[:,j,i] = attention
+                attentions_symmetric.append(x)
+            attentions = torch.stack([attention for attention in attentions_symmetric])
+        else:
+            # torch.Size([1, n_heads, n_res, n_res]) -> torch.Size([n_heads, n_res, n_res])
+            attention_squeezed = [torch.squeeze(attention) for attention in attentions]
+
+            # ([n_heads, n_res, n_res]*n_layers) -> [n_layers, n_heads, n_res, n_res]
+            attention_stacked = torch.stack([attention for attention in attention_squeezed])
+            attentions = attention_stacked
+
     elif model_type == ModelType.PROT_T5:
         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         # Introduce white-space between all amino acids
@@ -98,7 +141,7 @@ def get_attention(
     else:
         raise ValueError(f"Model {model_type} not supported")
 
-    return attns
+    return attentions
 
 def unidirectional_sum_filtered(attention, layer, head, threshold):
     num_layers, num_heads, seq_len, _ = attention.shape
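The least obvious piece of the ZymCTRL branch is the fallback that rebuilds square attention matrices when each per-layer tensor comes back with trailing dimensions of size 1: the values are treated as the upper triangle (diagonal included) of a symmetric per-head matrix and mirrored with torch.triu_indices. Below is a standalone toy demo of that indexing trick, using made-up shapes rather than real model output.

import torch

# Rebuild a symmetric [n_res, n_res] matrix per head from its upper-triangular entries.
n_heads, n_res = 2, 4
n_upper = n_res * (n_res + 1) // 2               # 10 entries per head
upper = torch.rand(n_heads, n_upper)             # stand-in for the flattened attention values

i, j = torch.triu_indices(n_res, n_res)          # row/col indices of the upper triangle
full = torch.zeros(n_heads, n_res, n_res)
full[:, i, j] = upper                            # fill the upper triangle
full[:, j, i] = upper                            # mirror into the lower triangle

assert torch.equal(full, full.transpose(-1, -2)) # symmetric by construction

For the assignment to broadcast, the number of values per head has to equal the number of upper-triangle index pairs, n_res * (n_res + 1) / 2.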
tests/test_attention.py CHANGED
@@ -1,9 +1,10 @@
 import torch
 from Bio.PDB.Structure import Structure
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, T5EncoderModel,
+                          T5Tokenizer)
 
 from hexviz.attention import (ModelType, get_attention, get_protT5,
-                              get_sequences, get_structure,
+                              get_sequences, get_structure, get_zymctrl,
                               unidirectional_sum_filtered)
 
 
@@ -37,12 +38,42 @@ def test_get_protT5():
     assert isinstance(tokenizer, T5Tokenizer)
     assert isinstance(model, T5EncoderModel)
 
+def test_get_zymctrl():
+    result = get_zymctrl()
+
+    assert result is not None
+    assert isinstance(result, tuple)
+
+    tokenizer, model = result
+
+    assert isinstance(tokenizer, GPT2TokenizerFast)
+    assert isinstance(model, GPT2LMHeadModel)
+
+def test_get_attention_zymctrl():
+
+    result = get_attention("GGG", model_type=ModelType.ZymCTRL)
+
+    assert result is not None
+    assert result.shape == torch.Size([36,16,3,3])
+
+def test_get_attention_zymctrl_long_chain():
+    structure = get_structure(pdb_code="6A5J") # 13 residues long
+
+    sequences = get_sequences(structure)
+
+    result = get_attention(sequences[0], model_type=ModelType.ZymCTRL)
+
+    assert result is not None
+    assert result.shape == torch.Size([36,16,13,13])
+
 def test_get_attention_tape():
+    structure = get_structure(pdb_code="6A5J") # 13 residues long
+    sequences = get_sequences(structure)
 
-    result = get_attention("1AKE", model=ModelType.TAPE_BERT)
+    result = get_attention(sequences[0], model_type=ModelType.TAPE_BERT)
 
     assert result is not None
-    assert result.shape == torch.Size([12,12,456,456])
+    assert result.shape == torch.Size([12,12,13,13])
 
 def test_get_unidirection_sum_filtered():
     # 1 head, 1 layer, 4 residues long attention tensor
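The new tests download the ZymCTRL checkpoint from the Hugging Face hub and fetch PDB entry 6A5J via get_structure, so they need network access and are slow on a CPU-only machine. Assuming a standard pytest setup (not specified in this commit), the ZymCTRL tests can be run on their own with:

pytest tests/test_attention.py -k zymctrl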