import dataclasses
import math
from typing import List, Optional

import torch
from pymilvus import MilvusClient, connections
from transformers import AutoModel, AutoTokenizer

from credentials import get_token


def _milvus_string_literal(value: str) -> str:
    """Return *value* as a single-quoted Milvus filter-expression string literal.

    Backslashes and single quotes are backslash-escaped so caller-supplied
    text (e.g. an organism name) cannot close the literal early and inject
    extra clauses into the filter expression.
    """
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


@dataclasses.dataclass
class MilvusParams:
    """Connection settings for the Milvus collection that holds the embeddings."""

    uri: str
    token: str
    db_name: str
    collection_name: str


class ProteinSearchEngine:
    """Embed sequences with a Hugging Face model and run nearest-neighbour search in Milvus."""

    # Embedding dimensionality; used to normalise distances into scores.
    n_dims = 128
    # Not read by any method in this module.
    dist_metric = "euclidean"
    # Candidate tokenizer max_length values; search_by_sequence uses index 0.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: MilvusParams, model_repo: str):
        """Connect to Milvus and load the tokenizer/model from *model_repo*.

        Args:
            milvus_params: Milvus URI, token, database and collection name.
            model_repo: Hugging Face repository id for tokenizer and model.
                The model is loaded with ``trust_remote_code=True`` and must
                expose a ``forward1`` method (used by ``_embed_sequence``).
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # search() and get_organisms() go through this client, so it must be
        # pointed at the same database as the connection above; without
        # db_name it would target the server's default database.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        hf_token = get_token()
        # NOTE(review): newer transformers releases deprecate use_auth_token in
        # favour of token=; kept here for compatibility with the pinned version.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=hf_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=hf_token, trust_remote_code=True
        )
        # Inference only: disable dropout and other training-mode behaviour.
        self.model.eval()

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed *sequence* and return up to *n* nearest non-peptide entries.

        Args:
            sequence: Raw sequence string passed to the tokenizer.
            n: Maximum number of hits to return.
            organism: If given, only entries whose ``organism`` field equals
                this value are considered.

        Returns:
            List of entity dicts with added ``dist`` and ``score`` keys
            (see ``_format_search_results``).
        """
        # NOTE(review): truncates to max_lengths[0] (30 tokens) although the
        # search is restricted to is_peptide == False; confirm this length is
        # intended for non-peptide sequences.
        max_length = self.max_lengths[0]
        vec = self._embed_sequence(max_length, sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length, sequence):
        """Tokenize *sequence* (padded/truncated to *max_length*) and embed it.

        Returns:
            The model's ``forward1`` output, squeezed and converted to a numpy
            array on the CPU.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            vec = (
                self.model.forward1(encoded.to(self.model.device))
                .squeeze()
                .cpu()
                .numpy()
            )
        return vec

    def _format_search_results(self, response):
        """Attach a distance and a normalised score to each search hit.

        Each hit's ``entity`` dict is mutated in place: ``dist`` is the square
        root of the reported distance and ``score`` is
        ``(max_dist - dist) / max_dist`` with ``max_dist = sqrt(2 * n_dims)``.

        Returns:
            List of the (mutated) entity dicts, in response order.
        """
        # NOTE(review): assumes distances between stored embeddings are bounded
        # by sqrt(2 * n_dims); scores go negative if that does not hold.
        max_dist = math.sqrt(2 * self.n_dims)
        search_results = []
        for res in response:
            entry = res["entity"]
            # The reported distance is treated as a squared L2 distance; clamp
            # tiny negative values from floating-point rounding so that
            # math.sqrt cannot raise ValueError.
            dist = math.sqrt(max(res["distance"], 0.0))
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a vector search against the configured collection.

        Args:
            vec: Query embedding (e.g. the array returned by ``_embed_sequence``).
            n_results: Maximum number of hits (Milvus ``limit``).
            is_peptide: Restrict hits to rows whose ``is_peptide`` field matches.
            organism: If given, also restrict hits to this exact organism.

        Returns:
            The hit list for the single query vector.
        """
        is_peptide = bool(is_peptide)
        filter_str = f"is_peptide == {is_peptide}"
        if organism is not None:
            # Escape the value so quotes in it cannot alter the expression.
            filter_str += f" and organism == {_milvus_string_literal(organism)}"
        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=filter_str,
        )
        # Only one query vector was sent, so take its result list.
        return results[0]

    def get_organisms(self):
        """Return the ``organism`` field of every row with ``entry_id > 0``.

        Rows are returned as produced by ``MilvusClient.query`` and are not
        de-duplicated.

        NOTE(review): no ``limit`` is passed; confirm the server returns all
        matching rows for the collection size in use.
        """
        res = self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
        return res