File size: 3,418 Bytes
e873d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import dataclasses
import math
from typing import List, Optional

import torch
from pymilvus import MilvusClient, connections
from transformers import AutoModel, AutoTokenizer

from credentials import get_token


@dataclasses.dataclass
class MilvusParams:
    """Connection settings for the Milvus collection that is searched.

    The ``token`` field is a credential, so it is excluded from the
    generated ``repr`` to keep it out of logs and tracebacks. It still
    takes part in equality and is readable as a normal attribute.
    """

    # Endpoint passed as ``uri`` when connecting to Milvus.
    uri: str
    # Credential passed as ``token`` when connecting; hidden from repr().
    token: str = dataclasses.field(repr=False)
    # Database name passed as ``db_name`` when connecting.
    db_name: str
    # Name of the collection that searches and queries are run against.
    collection_name: str


class ProteinSearchEngine:
    """Embed sequences with a Hugging Face model and search them in Milvus.

    On construction the engine connects to Milvus and loads the tokenizer
    and model from ``model_repo``. Queries are embedded with the model and
    matched against the collection named in the Milvus parameters.
    """

    # Used to build the score normaliser sqrt(2 * n_dims) in
    # _format_search_results.
    n_dims = 128
    # Not read anywhere inside this class.
    dist_metric = "euclidean"
    # search_by_sequence tokenises queries to max_lengths[0] tokens; the
    # second entry is not used inside this class.
    max_lengths = (30, 300)

    def __init__(self, milvus_params: "MilvusParams", model_repo: str):
        """Connect to Milvus and load the embedding model.

        Args:
            milvus_params: Endpoint, credential, database and collection.
            model_repo: Repository id handed to ``from_pretrained``.
        """
        self.model_repo = model_repo
        self.milvus_params = milvus_params
        connections.connect(
            "default",
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # Pass db_name here as well: the client is what search/query go
        # through, so it has to target the same database as the connection
        # opened above rather than falling back to the server default.
        self.client = MilvusClient(
            uri=milvus_params.uri,
            token=milvus_params.token,
            db_name=milvus_params.db_name,
        )
        # Fetch the hub credential once and reuse it for both downloads.
        auth_token = get_token()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, use_auth_token=auth_token
        )
        self.model = AutoModel.from_pretrained(
            self.model_repo, use_auth_token=auth_token, trust_remote_code=True
        )
        self.model.eval()

    def search_by_sequence(self, sequence: str, n: int, organism: Optional[str] = None):
        """Embed ``sequence`` and return the ``n`` closest non-peptide entries.

        Args:
            sequence: Query sequence handed to the tokenizer.
            n: Maximum number of hits to return.
            organism: When given, only entries with this exact organism
                value are considered.

        Returns:
            A list of entity dicts, each extended with ``dist`` and ``score``.
        """
        max_length = self.max_lengths[0]
        vec = self._embed_sequence(max_length, sequence)
        response = self.search(vec, n_results=n, is_peptide=False, organism=organism)
        return self._format_search_results(response)

    def _embed_sequence(self, max_length: int, sequence: str):
        """Tokenise ``sequence`` and run it through the model's ``forward1``.

        The input is truncated or padded to exactly ``max_length`` tokens.
        The result is squeezed, moved to CPU and converted to a NumPy array.
        """
        encoded = self.tokenizer.encode_plus(
            sequence,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Inference only: no gradients are needed for an embedding lookup.
        with torch.no_grad():
            # NOTE(review): forward1 comes from the remote model code
            # (trust_remote_code=True); presumably it returns one embedding
            # per input -- confirm against the model repository.
            vec = (
                self.model.forward1(encoded.to(self.model.device))
                .squeeze()
                .cpu()
                .numpy()
            )
        return vec

    def _format_search_results(self, response):
        """Attach ``dist`` and ``score`` to each hit's entity dict.

        ``dist`` is the square root of the reported ``distance`` and
        ``score`` is ``(max_dist - dist) / max_dist`` with
        ``max_dist = sqrt(2 * n_dims)``. The entity dicts are updated in
        place and returned in the order they were received.

        Args:
            response: Iterable of hits, each with ``"entity"`` and
                ``"distance"`` keys.

        Returns:
            The list of updated entity dicts.
        """
        search_results = []
        max_dist = math.sqrt(2 * self.n_dims)
        for res in response:
            entry = res["entity"]
            # Clamp at zero: rounding can produce a tiny negative value for
            # (near-)identical vectors, which would make math.sqrt raise.
            # NOTE(review): taking the root implies the backend reports a
            # squared distance -- confirm for the configured metric.
            dist = math.sqrt(max(res["distance"], 0.0))
            entry["dist"] = dist
            entry["score"] = (max_dist - dist) / max_dist
            search_results.append(entry)
        return search_results

    @staticmethod
    def _quote_filter_string(value: str) -> str:
        """Return ``value`` as a single-quoted literal for a filter expression.

        Backslashes and single quotes are backslash-escaped so the value
        cannot terminate the literal and inject extra filter clauses.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def search(
        self,
        vec: List[float],
        n_results: int,
        is_peptide: bool,
        organism: Optional[str] = None,
    ):
        """Run a vector search against the configured collection.

        Args:
            vec: Query embedding.
            n_results: Maximum number of hits to return.
            is_peptide: Restrict hits to entries whose ``is_peptide`` field
                equals this value (coerced with ``bool``).
            organism: When given, additionally require an exact match on
                the ``organism`` field.

        Returns:
            The hits for the single query vector (first element of the
            client's response).
        """
        is_peptide = bool(is_peptide)
        filter_str = f"is_peptide == {is_peptide}"
        if organism is not None:
            # organism is caller-supplied text; escape it before embedding
            # it in the expression instead of interpolating it raw.
            filter_str += f" and organism == {self._quote_filter_string(organism)}"

        results = self.client.search(
            collection_name=self.milvus_params.collection_name,
            data=[vec],
            limit=n_results,
            output_fields=[
                "genes",
                "uniprot_id",
                "pdb_name",
                "chain_id",
                "is_peptide",
                "organism",
            ],
            filter=filter_str,
        )
        # One query vector was sent, so only the first result set is relevant.
        return results[0]

    def get_organisms(self):
        """Query the ``organism`` field of every row with ``entry_id > 0``.

        Returns:
            The raw query result from the client; no de-duplication is
            applied here.
        """
        res = self.client.query(
            collection_name=self.milvus_params.collection_name,
            output_fields=["organism"],
            filter="entry_id > 0",
        )
        return res