File size: 6,414 Bytes
33eb5d4
e873d33
33eb5d4
 
27e2770
 
e873d33
33eb5d4
e873d33
 
 
 
 
 
 
 
 
 
 
 
 
 
27e2770
e1f535f
33eb5d4
e103bde
27e2770
1694358
33eb5d4
e873d33
e103bde
33eb5d4
e873d33
 
 
 
 
33eb5d4
 
 
79b6488
 
 
e103bde
 
 
 
 
79b6488
5e7a3eb
79b6488
 
33eb5d4
6f068fd
33eb5d4
4a6ff7a
 
6f068fd
4a6ff7a
6f068fd
 
 
 
 
 
 
33eb5d4
 
 
 
 
 
 
 
6f068fd
33eb5d4
 
 
6f068fd
 
33eb5d4
 
 
 
 
b2a3d53
6f068fd
33eb5d4
6f068fd
 
 
 
 
 
 
 
6509a73
33eb5d4
6509a73
d46e61e
 
 
 
 
 
6509a73
d46e61e
6509a73
d46e61e
 
 
 
 
 
6509a73
 
79b6488
 
6509a73
1d11011
 
 
1694358
 
3380f3c
1694358
 
1d11011
 
 
79b6488
6509a73
 
d46e61e
 
 
 
6509a73
33eb5d4
 
d46e61e
33eb5d4
b2a3d53
33eb5d4
b2a3d53
d46e61e
 
 
6509a73
 
dfecb5b
 
27e2770
cdbbabd
dfecb5b
 
 
6509a73
9608c6c
6509a73
 
d48c285
6509a73
 
 
 
b2a3d53
5e7a3eb
d46e61e
5e7a3eb
79b6488
 
b2a3d53
 
 
 
79b6488
1d25e2a
dfecb5b
79b6488
b2a3d53
 
 
94ce4d7
b092b28
dfecb5b
 
 
6509a73
b2a3d53
79b6488
 
b2a3d53
dfecb5b
 
b2a3d53
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import collections
import os
from typing import Dict, List

import gradio as gr

from index_list import read_index_list
from protein_viz import get_pdb_title, render_html
from search_engine import MilvusParams, ProteinSearchEngine

# Model identifier handed to the search engine as ``model_repo``
# (presumably a Hugging Face repo id -- TODO confirm).
model_repo = "ronig/protein_biencoder"

# Choices offered by the "Index" dropdown in the UI below.
available_indexes = read_index_list()
# Shared search engine used by ``search_and_display``. The Milvus token is
# read from the MILVUS_TOKEN environment variable and is None when unset.
engine = ProteinSearchEngine(
    milvus_params=MilvusParams(
        uri="https://in03-ddab8e9a5a09fcc.api.gcp-us-west1.zillizcloud.com",
        token=os.environ.get("MILVUS_TOKEN"),
        db_name="Protein",
        collection_name="Peptriever",
    ),
    model_repo=model_repo,
)

# Upper bound applied by ``limit_n_results``; also the number of classes the
# "Search Results" label is allowed to show.
max_results = 1000
# Separator used to build dropdown choices in ``update_dropdown_menu`` and to
# split them again in ``switch_viz``.
choice_sep = " | "
# Longest input sequence accepted by ``_validate_sequence_length``.
max_seq_length = 50


def search_and_display(seq, max_res, index_selection):
    """Search for a sequence and build the values for the result widgets.

    Args:
        seq: Query sequence entered by the user.
        max_res: Requested number of results; clamped by ``limit_n_results``
            and converted to ``int``.
        index_selection: Organism index to restrict the search to. The value
            ``"All Species"`` disables the organism filter.

    Returns:
        A tuple ``(formatted_search_results, results_options)``: the mapping
        produced by ``format_search_results`` and the dropdown update produced
        by ``update_dropdown_menu``.

    Raises:
        gr.Error: If the sequence is too long or no result count was given.
    """
    # Number of raw hits requested from the engine before aggregation.
    n_search_res = 1024
    _validate_sequence_length(seq)
    if max_res is None:
        # A cleared number field can arrive as None; without this guard the
        # clamp below fails with an opaque TypeError instead of a clear message.
        raise gr.Error("Please enter the number of results to display")
    max_res = int(limit_n_results(max_res))
    if index_selection == "All Species":
        index_selection = None
    search_res = engine.search_by_sequence(
        seq, n=n_search_res, organism=index_selection
    )
    agg_search_results = aggregate_search_results(search_res, max_res)
    formatted_search_results = format_search_results(agg_search_results)
    results_options = update_dropdown_menu(agg_search_results)
    return formatted_search_results, results_options


def _validate_sequence_length(seq):
    """Reject sequences longer than ``max_seq_length``.

    Raises:
        gr.Error: If ``len(seq)`` exceeds ``max_seq_length``.
    """
    if len(seq) <= max_seq_length:
        return
    raise gr.Error("Only peptide input is currently supported")


def limit_n_results(n):
    """Clamp ``n`` to the range from 1 to ``max_results``."""
    capped = min(n, max_results)
    return max(capped, 1)


def aggregate_search_results(raw_results: List[dict], max_res: int) -> Dict[str, dict]:
    """Group raw search hits by their UniProt id.

    Hits are consumed in the order given. Hits whose ``uniprot_id`` is None
    are skipped. Collection stops as soon as the number of distinct UniProt
    ids reaches ``max_res``, so later hits are not included.

    Args:
        raw_results: Hit dicts containing at least the keys ``pdb_name``,
            ``chain_id``, ``score``, ``organism``, ``uniprot_id`` and ``genes``.
        max_res: Number of distinct UniProt ids at which to stop.

    Returns:
        A plain dict mapping each UniProt id to the list of its trimmed hits.
    """
    wanted_keys = ["pdb_name", "chain_id", "score", "organism", "uniprot_id", "genes"]
    grouped = collections.defaultdict(list)
    for hit in raw_results:
        trimmed = select_keys(hit, keys=wanted_keys)
        protein_id = hit["uniprot_id"]
        if protein_id is None:
            continue

        grouped[protein_id].append(trimmed)
        if len(grouped) >= max_res:
            break
    return dict(grouped)


def select_keys(d: dict, keys: List[str]):
    """Return a new dict holding only the entries of ``d`` named in ``keys``.

    Raises:
        KeyError: If any requested key is missing from ``d``.
    """
    subset = {}
    for name in keys:
        subset[name] = d[name]
    return subset


def format_search_results(agg_search_results):
    """Build the label-to-score mapping shown in the results widget.

    Each UniProt id contributes one item, described by the first entry in its
    list: the label combines the UniProt id, organism and gene names, and the
    value is that entry's score.
    """
    labelled_scores = {}
    first_entries = (
        (uniprot_id, entries[0]) for uniprot_id, entries in agg_search_results.items()
    )
    for uniprot_id, first in first_entries:
        organism, score, genes = first["organism"], first["score"], first["genes"]
        label = f"Uniprot ID: {uniprot_id} | Organism: {organism} | Gene Names: {genes}"
        labelled_scores[label] = score
    return labelled_scores


def update_dropdown_menu(agg_search_res):
    """Build the update for the "Visualized Search Result" dropdown.

    One choice is produced per aggregated hit, joining the UniProt id, PDB
    name, chain id and gene names (empty string when falsy) with
    ``choice_sep``. With at least one choice the dropdown is shown and the
    first choice is selected; with none it is hidden and its value cleared.
    """
    options = [
        choice_sep.join(
            [
                protein_id,
                hit["pdb_name"],
                hit["chain_id"],
                hit["genes"] or "",
            ]
        )
        for protein_id, hits in agg_search_res.items()
        for hit in hits
    ]

    if not options:
        # Nothing to visualize: hide the dropdown and clear its value.
        return gr.update(
            gr.Dropdown.get_component_class_id(),
            choices=options,
            interactive=True,
            visible=False,
            value=None,
        )

    return gr.update(
        gr.Dropdown.get_component_class_id(),
        choices=options,
        interactive=True,
        value=options[0],
        visible=True,
    )


def parse_pdb_search_result(raw_result):
    """Turn one raw hit into a ``(label, score)`` pair.

    The label always names the PDB entry and chain. Gene names and organism
    are appended only when the hit's ``genes`` value is not None.
    """
    pdb_name = raw_result["pdb_name"]
    chain_id = raw_result["chain_id"]
    score = raw_result["score"]
    genes = raw_result["genes"]
    organism = raw_result["organism"]

    label = f"PDB: {pdb_name}.{chain_id}"
    if genes is None:
        return label, score
    return label + f" | Genes: {genes} | Organism: {organism}", score


def switch_viz(new_choice):
    """Produce the visualization outputs for the selected dropdown choice.

    Args:
        new_choice: A choice string built with ``choice_sep`` whose second and
            third parts are the PDB id and chain, or None when nothing is
            selected.

    Returns:
        A tuple ``(html, title_update, description_update)``. With no
        selection the HTML is empty and both markdown components are hidden;
        otherwise the structure is rendered and the PDB title is shown.
    """
    if new_choice is None:
        hidden_title = gr.update(gr.Markdown.get_component_class_id(), visible=False)
        hidden_description = gr.update(
            gr.Markdown.get_component_class_id(), value=None, visible=False
        )
        return "", hidden_title, hidden_description

    pdb_id, chain = new_choice.split(choice_sep)[1:3]
    shown_title = gr.update(gr.Markdown.get_component_class_id(), visible=True)
    # Fetch the title before rendering, so a title lookup failure skips rendering.
    description_text = f"""**PDB Title**: {get_pdb_title(pdb_id)}"""
    shown_description = gr.update(
        gr.Markdown.get_component_class_id(), value=description_text, visible=True
    )
    rendered = render_html(pdb_id=pdb_id, chain=chain)
    return rendered, shown_title, shown_description


# Build the Gradio interface: search inputs and the results label are declared
# first, followed by the visualization widgets and the example sequences.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # Search inputs: query sequence, result count and index.
                    seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
                    n_results = gr.Number(10, label="N Results")
                    index_selector = gr.Dropdown(
                        choices=available_indexes,
                        value="All Species",
                        multiselect=False,
                        visible=True,
                        label="Index",
                    )
                    search_button = gr.Button("Search", variant="primary")
                # Receives the label -> score mapping from search_and_display.
                search_results = gr.Label(
                    num_top_classes=max_results, label="Search Results", scale=2
                )
            # Visualization widgets start hidden; switch_viz and
            # update_dropdown_menu toggle their visibility.
            viz_header = gr.Markdown("## Visualization", visible=False)
            results_selector = gr.Dropdown(
                choices=[],
                multiselect=False,
                visible=False,
                label="Visualized Search Result",
            )
            viz_body = gr.Markdown("", visible=False)
            # Initial HTML is rendered with no PDB id or chain selected.
            protein_viz = gr.HTML(
                value=render_html(pdb_id=None, chain=None),
                label="Protein Visualization",
            )
            # Example sequences that fill the sequence textbox.
            gr.Examples(
                ["APTMPPPLPP", "KFLIYQMECSTMIFGL", "PHFAMPPIHEDHLE", "AEERIISLD"],
                inputs=[seq_input],
            )
    # Clicking "Search" runs the search and updates the results label and the
    # dropdown of hits available for visualization.
    search_button.click(
        search_and_display,
        inputs=[seq_input, n_results, index_selector],
        outputs=[search_results, results_selector],
    )
    # Changing the selected hit refreshes the viewer, header and description.
    results_selector.change(
        switch_viz, inputs=results_selector, outputs=[protein_viz, viz_header, viz_body]
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()