protein_binding_search / search_engine.py
roni
App switched to use Milvus instead of Annoy
e873d33
import dataclasses
import math
from typing import List, Optional
import torch
from pymilvus import MilvusClient, connections
from transformers import AutoModel, AutoTokenizer
from credentials import get_token
@dataclasses.dataclass
class MilvusParams:
    """Connection settings for the Milvus collection used by the search engine."""

    uri: str  # Milvus server URI handed to pymilvus.
    # Authentication token. Excluded from repr() so printing/logging the params
    # object (or a traceback that shows it) does not leak the secret.
    token: str = dataclasses.field(repr=False)
    db_name: str  # Milvus database that holds the collection.
    collection_name: str  # Collection that ProteinSearchEngine searches/queries.
class ProteinSearchEngine:
    """Embed sequences with a Hugging Face model and search them in Milvus.

    The tokenizer/model are loaded from ``model_repo`` (with
    ``trust_remote_code=True``) and the model is put in eval mode. Vector search
    and metadata queries go to the collection described by ``milvus_params``.
    """

    # Embedding dimensionality; used here only to derive ``max_dist`` for scoring.
    n_dims = 128
    # Not referenced inside this class; kept for external callers.
    dist_metric = "euclidean"
    # Tokenizer max lengths; ``search_by_sequence`` uses index 0.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: "MilvusParams", model_repo: str):
        """Connect to Milvus and load the tokenizer and model.

        Args:
            milvus_params: Milvus URI, token, database and collection to use.
            model_repo: Repository passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        # Registers the "default" ORM connection alias for the configured
        # database; kept in case other code relies on that alias.
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # MilvusClient opens its own connection, so the database must be given
        # here as well; without it every search/query hits the default database.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        hf_token = get_token()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=hf_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=hf_token, trust_remote_code=True
        )
        self.model.eval()

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed ``sequence`` and return the ``n`` nearest non-peptide entries.

        Args:
            sequence: Input sequence; tokenized and padded/truncated to
                ``max_lengths[0]`` tokens.
            n: Maximum number of hits to return.
            organism: If given, only entries whose ``organism`` field equals this
                value exactly are returned.

        Returns:
            List of entity dicts, as produced by ``_format_search_results``.
        """
        vec = self._embed_sequence(self.max_lengths[0], sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length, sequence):
        """Tokenize ``sequence`` and embed it with ``self.model.forward1``.

        The sequence is padded/truncated to ``max_length`` tokens (special tokens
        included). Returns the squeezed output as a NumPy array on the CPU —
        presumably of shape ``(n_dims,)``; confirm against the model.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            vec = (
                self.model.forward1(encoded.to(self.model.device))
                .squeeze()
                .cpu()
                .numpy()
            )
        return vec

    def _format_search_results(self, response):
        """Turn raw Milvus hits into entity dicts with ``dist`` and ``score`` added.

        ``dist`` is the square root of the reported distance. This presumably
        assumes an L2 collection, for which Milvus reports squared distances;
        confirm the collection's metric. ``score`` is
        ``(max_dist - dist) / max_dist`` with ``max_dist = sqrt(2 * n_dims)``.
        It is not clamped, so it is negative when ``dist > max_dist``.

        Each hit's ``entity`` dict is updated in place and returned.
        """
        max_dist = math.sqrt(2 * self.n_dims)
        search_results = []
        for hit in response:
            entry = hit["entity"]
            # Clamp tiny negative values caused by floating-point error, which
            # would otherwise make math.sqrt raise ValueError.
            dist = math.sqrt(max(0.0, hit["distance"]))
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    @staticmethod
    def _quote_string_literal(value) -> str:
        """Return ``value`` as a single-quoted Milvus filter string literal.

        Backslashes and single quotes are escaped, so the value cannot end the
        literal early and inject extra clauses into the filter expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a filtered vector search against the configured collection.

        Args:
            vec: Query embedding.
            n_results: Maximum number of hits (``limit``).
            is_peptide: Value the ``is_peptide`` field must equal (coerced to bool).
            organism: Optional exact-match filter on the ``organism`` field.
                The value is escaped before being embedded in the filter.

        Returns:
            The hit list for the single query vector (``results[0]``).
        """
        filter_str = f"is_peptide == {bool(is_peptide)}"
        if organism is not None:
            filter_str += f" and organism == {self._quote_string_literal(organism)}"
        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=filter_str,
        )
        return results[0]

    def get_organisms(self):
        """Query the ``organism`` field of every entity with ``entry_id > 0``.

        Returns the raw Milvus query result, i.e. one record per matching entity.
        Organism values can therefore repeat; deduplication is left to the caller.

        NOTE(review): no ``limit`` is passed. Confirm this stays within Milvus'
        query result limits as the collection grows.
        """
        res = self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
        return res