aksell committed
Commit eb9ae1f
1 Parent(s): e07549d

Split attention.py into attention.py and models.py

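This commit moves the model definitions (`ModelType`, `Model`) and the cached model loaders out of `hexviz/attention.py` into a new `hexviz/models.py`, leaving `attention.py` focused on attention extraction. After the split, callers pick up the two concerns from separate modules, as in this minimal sketch (an illustration, not part of the commit):

```python
# Import layout after the split (sketch).
from hexviz.attention import get_attention_pairs  # attention computation
from hexviz.models import Model, ModelType        # model metadata and loaders
```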
hexviz/app.py CHANGED
```diff
@@ -3,7 +3,8 @@ import stmol
 import streamlit as st
 from stmol import showmol
 
-from hexviz.attention import Model, ModelType, get_attention_pairs
+from hexviz.attention import get_attention_pairs
+from hexviz.models import Model, ModelType
 
 st.title("Attention Visualization on proteins")
 
@@ -12,8 +13,6 @@ Visualize attention weights on protein structures for the protein language model
 Pick a PDB ID, layer and head to visualize attention.
 """
 
-
-# Define list of model types
 models = [
     # Model(name=ModelType.ProtGPT2, layers=36, heads=20),
     Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
```
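Each `Model` entry carries per-model layer and head counts, which the app can use to bound its layer and head selectors. A hypothetical sketch of that use (the widgets below are assumptions, not shown in this diff):

```python
# Hypothetical selector sketch; widget labels and variable names are illustrative.
import streamlit as st

from hexviz.models import Model, ModelType

models = [Model(name=ModelType.TAPE_BERT, layers=12, heads=12)]
selected_name = st.selectbox("Model", [m.name.value for m in models])
selected = next(m for m in models if m.name.value == selected_name)
layer = st.number_input("Layer", min_value=1, max_value=selected.layers, value=1)
head = st.number_input("Head", min_value=1, max_value=selected.heads, value=1)
```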
hexviz/attention.py CHANGED
```diff
@@ -6,24 +6,10 @@ from urllib import request
 import streamlit as st
 import torch
 from Bio.PDB import PDBParser, Polypeptide, Structure
-from tape import ProteinBertModel, TAPETokenizer
-from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
-                          T5Tokenizer)
+from models import (ModelType, get_protgpt2, get_protT5, get_tape_bert,
+                    get_zymctrl)
 
 
-class ModelType(str, Enum):
-    TAPE_BERT = "TAPE-BERT"
-    PROT_T5 = "prot_t5_xl_half_uniref50-enc"
-    ZymCTRL = "ZymCTRL"
-    ProtGPT2 = "ProtGPT2"
-
-
-class Model:
-    def __init__(self, name, layers, heads):
-        self.name: ModelType = name
-        self.layers: int = layers
-        self.heads: int = heads
-
 @st.cache
 def get_structure(pdb_code: str) -> Structure:
     """
@@ -51,41 +37,6 @@ def get_sequences(structure: Structure) -> List[str]:
         sequences.append("".join(list(residues_single_letter)))
     return sequences
 
-@st.cache
-def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    tokenizer = T5Tokenizer.from_pretrained(
-        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
-    )
-
-    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
-        device
-    )
-
-    model.full() if device == "cpu" else model.half()
-
-    return tokenizer, model
-
-@st.cache
-def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
-    tokenizer = TAPETokenizer()
-    model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
-    return tokenizer, model
-
-@st.cache
-def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    tokenizer = AutoTokenizer.from_pretrained('nferruz/ZymCTRL')
-    model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
-    return tokenizer, model
-
-@st.cache
-def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
-    device = torch.device('cuda')
-    tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
-    model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
-    return tokenizer, model
-
 @st.cache
 def get_attention(
     sequence: str, model_type: ModelType = ModelType.TAPE_BERT
```
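`get_attention` now reaches the loaders through the new module. A minimal sketch of the kind of dispatch this enables (the `load_model` helper is illustrative; the commit's actual branching inside `get_attention` is not shown in this hunk):

```python
# Illustrative dispatch from a ModelType to its cached loader in models.py.
def load_model(model_type: ModelType):
    if model_type == ModelType.TAPE_BERT:
        return get_tape_bert()
    if model_type == ModelType.PROT_T5:
        return get_protT5()
    if model_type == ModelType.ZymCTRL:
        return get_zymctrl()
    if model_type == ModelType.ProtGPT2:
        return get_protgpt2()
    raise ValueError(f"Unknown model type: {model_type}")
```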
hexviz/models.py ADDED
```diff
@@ -0,0 +1,57 @@
+from enum import Enum
+from typing import Tuple
+
+import streamlit as st
+import torch
+from tape import ProteinBertModel, TAPETokenizer
+from transformers import (AutoTokenizer, GPT2LMHeadModel, T5EncoderModel,
+                          T5Tokenizer)
+
+
+class ModelType(str, Enum):
+    TAPE_BERT = "TAPE-BERT"
+    PROT_T5 = "prot_t5_xl_half_uniref50-enc"
+    ZymCTRL = "ZymCTRL"
+    ProtGPT2 = "ProtGPT2"
+
+
+class Model:
+    def __init__(self, name, layers, heads):
+        self.name: ModelType = name
+        self.layers: int = layers
+        self.heads: int = heads
+
+@st.cache
+def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = T5Tokenizer.from_pretrained(
+        "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
+    )
+
+    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
+        device
+    )
+
+    model.full() if device == "cpu" else model.half()
+
+    return tokenizer, model
+
+@st.cache
+def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
+    tokenizer = TAPETokenizer()
+    model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
+    return tokenizer, model
+
+@st.cache
+def get_zymctrl() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ZymCTRL')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ZymCTRL').to(device)
+    return tokenizer, model
+
+@st.cache
+def get_protgpt2() -> Tuple[AutoTokenizer, GPT2LMHeadModel]:
+    device = torch.device('cuda')
+    tokenizer = AutoTokenizer.from_pretrained('nferruz/ProtGPT2')
+    model = GPT2LMHeadModel.from_pretrained('nferruz/ProtGPT2').to(device)
+    return tokenizer, model
```
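Each loader is wrapped in `@st.cache`, so Streamlit memoizes the `(tokenizer, model)` pair and repeated calls across app reruns are cheap. A usage sketch (the example sequence is illustrative):

```python
import torch

from hexviz.models import get_tape_bert

tokenizer, model = get_tape_bert()  # downloads weights on first call, cached afterwards
token_ids = torch.tensor([tokenizer.encode("MRIV")])  # TAPE encode returns an id array
outputs = model(token_ids)  # includes attention weights, since the model was
                            # loaded with output_attentions=True
```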
tests/test_attention.py CHANGED
```diff
@@ -1,11 +1,8 @@
 import torch
 from Bio.PDB.Structure import Structure
-from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, T5EncoderModel,
-                          T5Tokenizer)
 
-from hexviz.attention import (ModelType, get_attention, get_protT5,
-                              get_sequences, get_structure, get_zymctrl,
-                              unidirectional_sum_filtered)
+from hexviz.attention import (ModelType, get_attention, get_sequences,
+                              get_structure, unidirectional_sum_filtered)
 
 
 def test_get_structure():
@@ -27,27 +24,6 @@ def test_get_sequences():
     A, B = sequences
     assert A[:3] == ["M", "R", "I"]
 
-def test_get_protT5():
-    result = get_protT5()
-
-    assert result is not None
-    assert isinstance(result, tuple)
-
-    tokenizer, model = result
-
-    assert isinstance(tokenizer, T5Tokenizer)
-    assert isinstance(model, T5EncoderModel)
-
-def test_get_zymctrl():
-    result = get_zymctrl()
-
-    assert result is not None
-    assert isinstance(result, tuple)
-
-    tokenizer, model = result
-
-    assert isinstance(tokenizer, GPT2TokenizerFast)
-    assert isinstance(model, GPT2LMHeadModel)
 
 def test_get_attention_zymctrl():
```
tests/test_models.py ADDED
```diff
@@ -0,0 +1,28 @@
+
+from transformers import (GPT2LMHeadModel, GPT2TokenizerFast, T5EncoderModel,
+                          T5Tokenizer)
+
+from hexviz.models import get_protT5, get_zymctrl
+
+
+def test_get_protT5():
+    result = get_protT5()
+
+    assert result is not None
+    assert isinstance(result, tuple)
+
+    tokenizer, model = result
+
+    assert isinstance(tokenizer, T5Tokenizer)
+    assert isinstance(model, T5EncoderModel)
+
+def test_get_zymctrl():
+    result = get_zymctrl()
+
+    assert result is not None
+    assert isinstance(result, tuple)
+
+    tokenizer, model = result
+
+    assert isinstance(tokenizer, GPT2TokenizerFast)
+    assert isinstance(model, GPT2LMHeadModel)
```
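The loader tests move here from tests/test_attention.py unchanged. Since each test downloads a full model checkpoint, they are slow on first run and can be run selectively with standard pytest filtering, e.g. `pytest tests/test_models.py -k get_zymctrl` (a usage note, not part of the commit).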