aksell committed on
Commit
466a8f2
1 Parent(s): b7ab123

Implement get_attention for tape BERT

Browse files
Files changed (4) hide show
  1. poetry.lock +118 -1
  2. protention/attention.py +30 -11
  3. pyproject.toml +1 -0
  4. tests/test_attention.py +10 -1
poetry.lock CHANGED
@@ -171,6 +171,38 @@ category = "main"
171
  optional = false
172
  python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  [[package]]
175
  name = "cachetools"
176
  version = "5.3.0"
@@ -572,6 +604,14 @@ MarkupSafe = ">=2.0"
572
  [package.extras]
573
  i18n = ["Babel (>=2.7)"]
574
 
 
 
 
 
 
 
 
 
575
  [[package]]
576
  name = "jsonpointer"
577
  version = "2.3"
@@ -749,6 +789,14 @@ category = "main"
749
  optional = false
750
  python-versions = "*"
751
 
 
 
 
 
 
 
 
 
752
  [[package]]
753
  name = "markdown-it-py"
754
  version = "2.2.0"
@@ -1474,6 +1522,36 @@ pygments = ">=2.13.0,<3.0.0"
1474
  [package.extras]
1475
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
1476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1477
  [[package]]
1478
  name = "semver"
1479
  version = "2.13.0"
@@ -1613,6 +1691,37 @@ python-versions = ">=3.8"
1613
  [package.dependencies]
1614
  mpmath = ">=0.19"
1615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1616
  [[package]]
1617
  name = "terminado"
1618
  version = "0.17.1"
@@ -1983,7 +2092,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
1983
  [metadata]
1984
  lock-version = "1.1"
1985
  python-versions = "^3.10"
1986
- content-hash = "c748285bd150fadef69123d60f0b4ad96d99715916c7e1ab30214132749f8aed"
1987
 
1988
  [metadata.files]
1989
  altair = []
@@ -2027,6 +2136,8 @@ beautifulsoup4 = []
2027
  biopython = []
2028
  bleach = []
2029
  blinker = []
 
 
2030
  cachetools = []
2031
  certifi = []
2032
  cffi = []
@@ -2071,6 +2182,7 @@ ipywidgets = []
2071
  isoduration = []
2072
  jedi = []
2073
  jinja2 = []
 
2074
  jsonpointer = []
2075
  jsonschema = []
2076
  jupyter-client = []
@@ -2082,6 +2194,7 @@ jupyter-server-terminals = []
2082
  jupyterlab-pygments = []
2083
  jupyterlab-widgets = []
2084
  lit = []
 
2085
  markdown-it-py = []
2086
  markupsafe = []
2087
  matplotlib-inline = []
@@ -2205,6 +2318,8 @@ requests = []
2205
  rfc3339-validator = []
2206
  rfc3986-validator = []
2207
  rich = []
 
 
2208
  semver = []
2209
  send2trash = [
2210
  {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
@@ -2225,6 +2340,8 @@ stack-data = []
2225
  stmol = []
2226
  streamlit = []
2227
  sympy = []
 
 
2228
  terminado = []
2229
  tinycss2 = []
2230
  tokenizers = []
 
171
  optional = false
172
  python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
173
 
174
+ [[package]]
175
+ name = "boto3"
176
+ version = "1.26.95"
177
+ description = "The AWS SDK for Python"
178
+ category = "main"
179
+ optional = false
180
+ python-versions = ">= 3.7"
181
+
182
+ [package.dependencies]
183
+ botocore = ">=1.29.95,<1.30.0"
184
+ jmespath = ">=0.7.1,<2.0.0"
185
+ s3transfer = ">=0.6.0,<0.7.0"
186
+
187
+ [package.extras]
188
+ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
189
+
190
+ [[package]]
191
+ name = "botocore"
192
+ version = "1.29.95"
193
+ description = "Low-level, data-driven core of boto 3."
194
+ category = "main"
195
+ optional = false
196
+ python-versions = ">= 3.7"
197
+
198
+ [package.dependencies]
199
+ jmespath = ">=0.7.1,<2.0.0"
200
+ python-dateutil = ">=2.1,<3.0.0"
201
+ urllib3 = ">=1.25.4,<1.27"
202
+
203
+ [package.extras]
204
+ crt = ["awscrt (==0.16.9)"]
205
+
206
  [[package]]
207
  name = "cachetools"
208
  version = "5.3.0"
 
604
  [package.extras]
605
  i18n = ["Babel (>=2.7)"]
606
 
607
+ [[package]]
608
+ name = "jmespath"
609
+ version = "1.0.1"
610
+ description = "JSON Matching Expressions"
611
+ category = "main"
612
+ optional = false
613
+ python-versions = ">=3.7"
614
+
615
  [[package]]
616
  name = "jsonpointer"
617
  version = "2.3"
 
789
  optional = false
790
  python-versions = "*"
791
 
792
+ [[package]]
793
+ name = "lmdb"
794
+ version = "1.4.0"
795
+ description = "Universal Python binding for the LMDB 'Lightning' Database"
796
+ category = "main"
797
+ optional = false
798
+ python-versions = "*"
799
+
800
  [[package]]
801
  name = "markdown-it-py"
802
  version = "2.2.0"
 
1522
  [package.extras]
1523
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
1524
 
1525
+ [[package]]
1526
+ name = "s3transfer"
1527
+ version = "0.6.0"
1528
+ description = "An Amazon S3 Transfer Manager"
1529
+ category = "main"
1530
+ optional = false
1531
+ python-versions = ">= 3.7"
1532
+
1533
+ [package.dependencies]
1534
+ botocore = ">=1.12.36,<2.0a.0"
1535
+
1536
+ [package.extras]
1537
+ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
1538
+
1539
+ [[package]]
1540
+ name = "scipy"
1541
+ version = "1.9.3"
1542
+ description = "Fundamental algorithms for scientific computing in Python"
1543
+ category = "main"
1544
+ optional = false
1545
+ python-versions = ">=3.8"
1546
+
1547
+ [package.dependencies]
1548
+ numpy = ">=1.18.5,<1.26.0"
1549
+
1550
+ [package.extras]
1551
+ test = ["pytest", "pytest-cov", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack"]
1552
+ doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-panels (>=0.5.2)", "matplotlib (>2)", "numpydoc", "sphinx-tabs"]
1553
+ dev = ["mypy", "typing-extensions", "pycodestyle", "flake8"]
1554
+
1555
  [[package]]
1556
  name = "semver"
1557
  version = "2.13.0"
 
1691
  [package.dependencies]
1692
  mpmath = ">=0.19"
1693
 
1694
+ [[package]]
1695
+ name = "tape-proteins"
1696
+ version = "0.5"
1697
+ description = "Repostory of Protein Benchmarking and Modeling"
1698
+ category = "main"
1699
+ optional = false
1700
+ python-versions = "*"
1701
+
1702
+ [package.dependencies]
1703
+ biopython = "*"
1704
+ boto3 = "*"
1705
+ filelock = "*"
1706
+ lmdb = "*"
1707
+ requests = "*"
1708
+ scipy = "*"
1709
+ tensorboardX = "*"
1710
+ tqdm = "*"
1711
+
1712
+ [[package]]
1713
+ name = "tensorboardx"
1714
+ version = "2.6"
1715
+ description = "TensorBoardX lets you watch Tensors Flow without Tensorflow"
1716
+ category = "main"
1717
+ optional = false
1718
+ python-versions = "*"
1719
+
1720
+ [package.dependencies]
1721
+ numpy = "*"
1722
+ packaging = "*"
1723
+ protobuf = ">=3.8.0,<4"
1724
+
1725
  [[package]]
1726
  name = "terminado"
1727
  version = "0.17.1"
 
2092
  [metadata]
2093
  lock-version = "1.1"
2094
  python-versions = "^3.10"
2095
+ content-hash = "ad6054ae4a119d961e9941f135489d1b89310303aefc27d3132fbd1ed1c35a0f"
2096
 
2097
  [metadata.files]
2098
  altair = []
 
2136
  biopython = []
2137
  bleach = []
2138
  blinker = []
2139
+ boto3 = []
2140
+ botocore = []
2141
  cachetools = []
2142
  certifi = []
2143
  cffi = []
 
2182
  isoduration = []
2183
  jedi = []
2184
  jinja2 = []
2185
+ jmespath = []
2186
  jsonpointer = []
2187
  jsonschema = []
2188
  jupyter-client = []
 
2194
  jupyterlab-pygments = []
2195
  jupyterlab-widgets = []
2196
  lit = []
2197
+ lmdb = []
2198
  markdown-it-py = []
2199
  markupsafe = []
2200
  matplotlib-inline = []
 
2318
  rfc3339-validator = []
2319
  rfc3986-validator = []
2320
  rich = []
2321
+ s3transfer = []
2322
+ scipy = []
2323
  semver = []
2324
  send2trash = [
2325
  {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
 
2340
  stmol = []
2341
  streamlit = []
2342
  sympy = []
2343
+ tape-proteins = []
2344
+ tensorboardx = []
2345
  terminado = []
2346
  tinycss2 = []
2347
  tokenizers = []
protention/attention.py CHANGED
@@ -1,11 +1,16 @@
 
1
  from io import StringIO
2
  from urllib import request
3
 
4
  import torch
5
  from Bio.PDB import PDBParser, Polypeptide, Structure
 
6
  from transformers import T5EncoderModel, T5Tokenizer
7
 
8
 
 
 
 
9
  def get_structure(pdb_code: str) -> Structure:
10
  """
11
  Get structure from PDB
@@ -46,9 +51,14 @@ def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
46
 
47
  return tokenizer, model
48
 
 
 
 
 
 
49
 
50
  def get_attention(
51
- pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
52
  ):
53
  """
54
  Get attention from T5
@@ -57,13 +67,22 @@ def get_attention(
57
  structure = get_structure(pdb_code)
58
  # Get list of sequences
59
  sequences = get_sequences(structure)
60
-
61
- # get model
62
- tokenizer, model = get_protT5()
63
-
64
- # call model
65
- ## Get sequence
66
-
67
- # get attention
68
-
69
- # extract attention
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
  from io import StringIO
3
  from urllib import request
4
 
5
  import torch
6
  from Bio.PDB import PDBParser, Polypeptide, Structure
7
+ from tape import ProteinBertModel, TAPETokenizer
8
  from transformers import T5EncoderModel, T5Tokenizer
9
 
10
 
11
+ class Model(str, Enum):
12
+ tape_bert = "bert-base"
13
+
14
  def get_structure(pdb_code: str) -> Structure:
15
  """
16
  Get structure from PDB
 
51
 
52
  return tokenizer, model
53
 
54
+ def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
55
+ tokenizer = TAPETokenizer()
56
+ model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
57
+ return tokenizer, model
58
+
59
 
60
  def get_attention(
61
+ pdb_code: str, model: Model = Model.tape_bert
62
  ):
63
  """
64
  Get attention from T5
 
67
  structure = get_structure(pdb_code)
68
  # Get list of sequences
69
  sequences = get_sequences(structure)
70
+ # TODO handle multiple sequences
71
+ sequence = sequences[0]
72
+
73
+ match model:
74
+ case model.tape_bert:
75
+ tokenizer, model = get_tape_bert()
76
+ token_idxs = tokenizer.encode(sequence).tolist()
77
+ inputs = torch.tensor(token_idxs).unsqueeze(0)
78
+ with torch.no_grad():
79
+ attns = model(inputs)[-1]
80
+ # Remove attention from <CLS> (first) and <SEP> (last) token
81
+ attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
82
+ attns = torch.stack([attn.squeeze(0) for attn in attns])
83
+ case model.prot_T5:
84
+ # Space separate sequences
85
+ sequences = [" ".join(sequence) for sequence in sequences]
86
+ tokenizer, model = get_protT5()
87
+
88
+ return attns
pyproject.toml CHANGED
@@ -12,6 +12,7 @@ biopython = "^1.81"
12
  transformers = "^4.27.1"
13
  torch = "^2.0.0"
14
  sentencepiece = "^0.1.97"
 
15
 
16
  [tool.poetry.dev-dependencies]
17
  pytest = "^7.2.2"
 
12
  transformers = "^4.27.1"
13
  torch = "^2.0.0"
14
  sentencepiece = "^0.1.97"
15
+ tape-proteins = "^0.5"
16
 
17
  [tool.poetry.dev-dependencies]
18
  pytest = "^7.2.2"
tests/test_attention.py CHANGED
@@ -1,7 +1,9 @@
 
1
  from Bio.PDB.Structure import Structure
2
  from transformers import T5EncoderModel, T5Tokenizer
3
 
4
- from protention.attention import get_protT5, get_sequences, get_structure
 
5
 
6
 
7
  def test_get_structure():
@@ -33,3 +35,10 @@ def test_get_protT5():
33
 
34
  assert isinstance(tokenizer, T5Tokenizer)
35
  assert isinstance(model, T5EncoderModel)
 
 
 
 
 
 
 
 
1
+ import torch
2
  from Bio.PDB.Structure import Structure
3
  from transformers import T5EncoderModel, T5Tokenizer
4
 
5
+ from protention.attention import (Model, get_attention, get_protT5,
6
+ get_sequences, get_structure)
7
 
8
 
9
  def test_get_structure():
 
35
 
36
  assert isinstance(tokenizer, T5Tokenizer)
37
  assert isinstance(model, T5EncoderModel)
38
+
39
+ def test_get_attention_tape():
40
+
41
+ result = get_attention("1AKE", model=Model.tape_bert)
42
+
43
+ assert result is not None
44
+ assert result.shape == torch.Size([12,12,456,456])