roni commited on
Commit
e873d33
1 Parent(s): 7217cfd

App switched to use Milvus instead of Annoy

Browse files
Files changed (8) hide show
  1. Makefile +1 -1
  2. app.py +22 -13
  3. get_index.py +0 -36
  4. index_list.py +11 -0
  5. pylintrc +0 -20
  6. requirements-dev.txt +1 -1
  7. requirements.txt +2 -1
  8. search_engine.py +113 -0
Makefile CHANGED
@@ -12,4 +12,4 @@ check-formatting:
12
  venv/bin/black --check .
13
 
14
  lint-python:
15
- venv/bin/pylint --rcfile=pylintrc .
 
12
  venv/bin/black --check .
13
 
14
  lint-python:
15
+ venv/bin/ruff .
app.py CHANGED
@@ -1,31 +1,40 @@
1
  import collections
 
2
  from typing import Dict, List
3
 
4
  import gradio as gr
5
 
6
- from get_index import get_engines
7
  from protein_viz import get_pdb_title, render_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- index_repo = "ronig/protein_index"
10
- model_repo = "ronig/protein_search_engine"
11
- engines = get_engines(index_repo, model_repo)
12
- available_indexes = list(engines.keys())
13
- app_description = """
14
- # Protein Binding Search Engine
15
- This application enables a quick protein-peptide binding search based on sequences.
16
- You can use it to search the full [PDB](https://www.rcsb.org/) database or in a specific organism genome.
17
- """
18
  max_results = 1000
19
  choice_sep = " | "
20
  max_seq_length = 50
21
 
22
 
23
  def search_and_display(seq, max_res, index_selection):
24
- n_search_res = 10000
25
  _validate_sequence_length(seq)
26
  max_res = int(limit_n_results(max_res))
27
- engine = engines[index_selection]
28
- search_res = engine.search_by_sequence(seq, n=n_search_res)
 
 
 
29
  agg_search_results = aggregate_search_results(search_res, max_res)
30
  formatted_search_results = format_search_results(agg_search_results)
31
  results_options = update_dropdown_menu(agg_search_results)
 
1
  import collections
2
+ import os
3
  from typing import Dict, List
4
 
5
  import gradio as gr
6
 
7
+ from index_list import read_index_list
8
  from protein_viz import get_pdb_title, render_html
9
+ from search_engine import MilvusParams, ProteinSearchEngine
10
+
11
+ model_repo = "ronig/protein_biencoder"
12
+
13
+ available_indexes = read_index_list()
14
+ engine = ProteinSearchEngine(
15
+ milvus_params=MilvusParams(
16
+ uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
17
+ token=os.environ.get("MILVUS_TOKEN"),
18
+ db_name="Protein",
19
+ collection_name="Peptriever",
20
+ ),
21
+ model_repo=model_repo,
22
+ )
23
 
 
 
 
 
 
 
 
 
 
24
  max_results = 1000
25
  choice_sep = " | "
26
  max_seq_length = 50
27
 
28
 
29
  def search_and_display(seq, max_res, index_selection):
30
+ n_search_res = 1024
31
  _validate_sequence_length(seq)
32
  max_res = int(limit_n_results(max_res))
33
+ if index_selection == "All Species":
34
+ index_selection = None
35
+ search_res = engine.search_by_sequence(
36
+ seq, n=n_search_res, organism=index_selection
37
+ )
38
  agg_search_results = aggregate_search_results(search_res, max_res)
39
  formatted_search_results = format_search_results(agg_search_results)
40
  results_options = update_dropdown_menu(agg_search_results)
get_index.py DELETED
@@ -1,36 +0,0 @@
1
- import os.path
2
- import sys
3
- from glob import glob
4
- from pathlib import Path
5
-
6
- from huggingface_hub import snapshot_download
7
-
8
- from credentials import get_token
9
-
10
-
11
- def get_engines(index_repo: str, model_repo: str):
12
- index_path = Path(
13
- snapshot_download(index_repo, use_auth_token=get_token(), repo_type="dataset")
14
- )
15
-
16
- local_arch_path = Path(
17
- snapshot_download(model_repo, use_auth_token=get_token(), repo_type="model")
18
- )
19
- sys.path.append(str(local_arch_path))
20
- from protein_index import ( # pylint: disable=import-error,import-outside-toplevel
21
- ProteinSearchEngine,
22
- ProteinIndexError,
23
- )
24
-
25
- subindex_paths = glob(str(index_path / "*/"))
26
- engines = {}
27
- for subindex_path in subindex_paths:
28
- subindex_name = os.path.basename(subindex_path)
29
- try:
30
- engine = ProteinSearchEngine(data_path=Path(subindex_path))
31
- if len(engine) > 1000:
32
- engines[subindex_name] = engine
33
- except ProteinIndexError:
34
- ...
35
-
36
- return engines
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
index_list.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+
4
+ def read_index_list():
5
+ here = os.path.dirname(__file__)
6
+ fname = os.path.join(here, "available_organisms.txt")
7
+ indexes = ["All Species"]
8
+ with open(fname) as f:
9
+ for index in f:
10
+ indexes.append(index.strip())
11
+ return indexes
pylintrc DELETED
@@ -1,20 +0,0 @@
1
- [MESSAGES CONTROL]
2
- disable=missing-docstring,invalid-name,logging-fstring-interpolation
3
-
4
- [DESIGN]
5
- min-public-methods=1
6
-
7
- [FORMAT]
8
- max-line-length=88
9
-
10
- [SIMILARITIES]
11
- min-similarity-lines=10
12
-
13
- [TYPECHECK]
14
-
15
- [MASTER]
16
- init-hook=import sys; sys.path.append(".")
17
- extension-pkg-whitelist=pydantic,cassandra
18
- generated-members=torch.*,cv2.*,np.random.*
19
- ignore-patterns=setup,py,tasks.py
20
- max-args=6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements-dev.txt CHANGED
@@ -1,5 +1,5 @@
1
  pytest
2
- pylint
3
  black
4
  mypy
5
  huggingface_hub
 
1
  pytest
2
+ ruff
3
  black
4
  mypy
5
  huggingface_hub
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
2
  transformers
3
  annoy
4
- mygene
 
 
1
  torch
2
  transformers
3
  annoy
4
+ mygene
5
+ pymilvus
search_engine.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import math
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+ from pymilvus import MilvusClient, connections
7
+ from transformers import AutoModel, AutoTokenizer
8
+
9
+ from credentials import get_token
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class MilvusParams:
14
+ uri: str
15
+ token: str
16
+ db_name: str
17
+ collection_name: str
18
+
19
+
20
+ class ProteinSearchEngine:
21
+ n_dims = 128
22
+ dist_metric = "euclidean"
23
+ max_lengths = (30, 300)
24
+
25
+ def __init__(self, milvus_params: MilvusParams, model_repo: str):
26
+ self.model_repo = model_repo
27
+ self.milvus_params = milvus_params
28
+ connections.connect(
29
+ "default",
30
+ uri=milvus_params.uri,
31
+ token=milvus_params.token,
32
+ db_name=milvus_params.db_name,
33
+ )
34
+ self.client = MilvusClient(uri=milvus_params.uri, token=milvus_params.token)
35
+ self.tokenizer = AutoTokenizer.from_pretrained(
36
+ self.model_repo, use_auth_token=get_token()
37
+ )
38
+ self.model = AutoModel.from_pretrained(
39
+ self.model_repo, use_auth_token=get_token(), trust_remote_code=True
40
+ )
41
+ self.model.eval()
42
+
43
+ def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
44
+ max_length = self.max_lengths[0]
45
+ vec = self._embed_sequence(max_length, sequence)
46
+ response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
47
+ search_results = self._format_search_results(response)
48
+ return search_results
49
+
50
+ def _embed_sequence(self, max_length, sequence):
51
+ encoded = self.tokenizer.encode_plus(
52
+ sequence,
53
+ add_special_tokens=True,
54
+ truncation=True,
55
+ max_length=max_length,
56
+ padding="max_length",
57
+ return_tensors="pt",
58
+ )
59
+ with torch.no_grad():
60
+ vec = (
61
+ self.model.forward1(encoded.to(self.model.device))
62
+ .squeeze()
63
+ .cpu()
64
+ .numpy()
65
+ )
66
+ return vec
67
+
68
+ def _format_search_results(self, response):
69
+ search_results = []
70
+ max_dist = math.sqrt(2 * self.n_dims)
71
+ for res in response:
72
+ entry = res["entity"]
73
+ dist = math.sqrt(res["distance"])
74
+ entry["dist"] = dist
75
+ entry["score"] = (max_dist - dist) / max_dist
76
+ search_results.append(entry)
77
+ return search_results
78
+
79
+ def search(
80
+ self,
81
+ vec: List[float],
82
+ n_results: int,
83
+ is_peptide: bool,
84
+ organism: Optional[str] = None,
85
+ ):
86
+ is_peptide = bool(is_peptide)
87
+ filter_str = f"is_peptide == {is_peptide}"
88
+ if organism is not None:
89
+ filter_str += f" and organism == '{organism}'"
90
+
91
+ results = self.client.search(
92
+ collection_name=self.milvus_params.collection_name,
93
+ data=[vec],
94
+ limit=n_results,
95
+ output_fields=[
96
+ "genes",
97
+ "uniprot_id",
98
+ "pdb_name",
99
+ "chain_id",
100
+ "is_peptide",
101
+ "organism",
102
+ ],
103
+ filter=filter_str,
104
+ )
105
+ return results[0]
106
+
107
+ def get_organisms(self):
108
+ res = self.client.query(
109
+ collection_name=self.milvus_params.collection_name,
110
+ output_fields=["organism"],
111
+ filter="entry_id > 0",
112
+ )
113
+ return res