Spaces:
Runtime error
Runtime error
roni
commited on
Commit
•
6509a73
1
Parent(s):
8a1f601
supporting multiple indexes
Browse files- app.py +61 -20
- get_index.py +10 -2
app.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
from concurrency import execute_multithread
|
4 |
-
from get_index import
|
5 |
from protein_viz import get_gene_name, get_protein_name, render_html
|
6 |
|
7 |
index_repo = "ronig/siamese_protein_index"
|
8 |
model_repo = "ronig/protein_search_engine"
|
9 |
-
|
|
|
10 |
|
11 |
|
12 |
-
def search_and_display(seq, n_res):
|
13 |
n_res = int(limit_n_results(n_res))
|
|
|
14 |
search_res = engine.search_by_sequence(seq, n=n_res)
|
15 |
results_options = update_dropdown_menu(search_res)
|
16 |
formatted_search_results = format_search_results(search_res)
|
@@ -24,12 +26,18 @@ def limit_n_results(n):
|
|
24 |
def update_dropdown_menu(search_res):
|
25 |
choices = []
|
26 |
for row in search_res:
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
def format_search_results(raw_search_results):
|
@@ -44,6 +52,22 @@ def format_search_results(raw_search_results):
|
|
44 |
|
45 |
|
46 |
def format_search_result(raw_result):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
prot = raw_result["pdb_name"]
|
48 |
chain = raw_result["chain_id"]
|
49 |
value = raw_result["score"]
|
@@ -54,16 +78,28 @@ def format_search_result(raw_result):
|
|
54 |
return key, value
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def switch_viz(new_choice):
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
|
65 |
-
|
66 |
-
|
|
|
67 |
|
68 |
|
69 |
with gr.Blocks() as demo:
|
@@ -79,10 +115,15 @@ with gr.Blocks() as demo:
|
|
79 |
with gr.Column():
|
80 |
with gr.Row():
|
81 |
with gr.Column():
|
82 |
-
seq_input = gr.Textbox(
|
83 |
-
value="APTMPPPLPP", label="Input Sequence"
|
84 |
-
)
|
85 |
n_results = gr.Number(5, label="N Results")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
search_button = gr.Button("Search", variant="primary")
|
87 |
search_results = gr.Label(num_top_classes=20, label="Search Results")
|
88 |
viz_header = gr.Markdown("## Visualization", visible=False)
|
@@ -103,7 +144,7 @@ with gr.Blocks() as demo:
|
|
103 |
)
|
104 |
search_button.click(
|
105 |
search_and_display,
|
106 |
-
inputs=[seq_input, n_results],
|
107 |
outputs=[search_results, results_selector],
|
108 |
)
|
109 |
results_selector.change(
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
from concurrency import execute_multithread
|
4 |
+
from get_index import get_engines
|
5 |
from protein_viz import get_gene_name, get_protein_name, render_html
|
6 |
|
7 |
index_repo = "ronig/siamese_protein_index"
|
8 |
model_repo = "ronig/protein_search_engine"
|
9 |
+
engines = get_engines(index_repo, model_repo)
|
10 |
+
available_indexes = list(engines.keys())
|
11 |
|
12 |
|
13 |
+
def search_and_display(seq, n_res, index_selection):
|
14 |
n_res = int(limit_n_results(n_res))
|
15 |
+
engine = engines[index_selection]
|
16 |
search_res = engine.search_by_sequence(seq, n=n_res)
|
17 |
results_options = update_dropdown_menu(search_res)
|
18 |
formatted_search_results = format_search_results(search_res)
|
|
|
26 |
def update_dropdown_menu(search_res):
|
27 |
choices = []
|
28 |
for row in search_res:
|
29 |
+
if "pdb_name" in row and "chain_id" in row:
|
30 |
+
choice = ".".join([row["pdb_name"], row["chain_id"]])
|
31 |
+
choices.append(choice)
|
32 |
+
if choices:
|
33 |
+
update = gr.Dropdown.update(
|
34 |
+
choices=choices, interactive=True, value=choices[0], visible=True
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
update = gr.Dropdown.update(
|
38 |
+
choices=choices, interactive=True, visible=False, value=None
|
39 |
+
)
|
40 |
+
return update
|
41 |
|
42 |
|
43 |
def format_search_results(raw_search_results):
|
|
|
52 |
|
53 |
|
54 |
def format_search_result(raw_result):
|
55 |
+
is_pdb = "pdb_name" in raw_result
|
56 |
+
if is_pdb:
|
57 |
+
key, value = parse_pdb_search_result(raw_result)
|
58 |
+
else:
|
59 |
+
key, value = parse_fasta_search_result(raw_result)
|
60 |
+
return key, value
|
61 |
+
|
62 |
+
|
63 |
+
def parse_fasta_search_result(raw_result):
|
64 |
+
gene = parse_gene_from_fasta_entry(raw_result["description"])
|
65 |
+
key = f"Gene: {gene}"
|
66 |
+
value = raw_result["score"]
|
67 |
+
return key, value
|
68 |
+
|
69 |
+
|
70 |
+
def parse_pdb_search_result(raw_result):
|
71 |
prot = raw_result["pdb_name"]
|
72 |
chain = raw_result["chain_id"]
|
73 |
value = raw_result["score"]
|
|
|
78 |
return key, value
|
79 |
|
80 |
|
81 |
+
def parse_gene_from_fasta_entry(description):
|
82 |
+
after = description.split("GN=")[1]
|
83 |
+
gene = after.split(" ")[0]
|
84 |
+
return gene
|
85 |
+
|
86 |
+
|
87 |
def switch_viz(new_choice):
|
88 |
+
if new_choice is None:
|
89 |
+
html = ""
|
90 |
+
title_update = gr.Markdown.update(visible=False)
|
91 |
+
description_update = gr.Markdown.update(value=None, visible=False)
|
92 |
+
else:
|
93 |
+
choice_parts = new_choice.split(".")
|
94 |
+
pdb_id, chain = choice_parts[0], choice_parts[1]
|
95 |
+
title_update = gr.Markdown.update(visible=True)
|
96 |
+
protein_name = get_protein_name(pdb_id)
|
97 |
|
98 |
+
new_value = f"""**PDB Title**: {protein_name}"""
|
99 |
|
100 |
+
description_update = gr.Markdown.update(value=new_value, visible=True)
|
101 |
+
html = render_html(pdb_id=pdb_id, chain=chain)
|
102 |
+
return html, title_update, description_update
|
103 |
|
104 |
|
105 |
with gr.Blocks() as demo:
|
|
|
115 |
with gr.Column():
|
116 |
with gr.Row():
|
117 |
with gr.Column():
|
118 |
+
seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
|
|
|
|
|
119 |
n_results = gr.Number(5, label="N Results")
|
120 |
+
index_selector = gr.Dropdown(
|
121 |
+
choices=available_indexes,
|
122 |
+
value=available_indexes[0],
|
123 |
+
multiselect=False,
|
124 |
+
visible=True,
|
125 |
+
label="Index",
|
126 |
+
)
|
127 |
search_button = gr.Button("Search", variant="primary")
|
128 |
search_results = gr.Label(num_top_classes=20, label="Search Results")
|
129 |
viz_header = gr.Markdown("## Visualization", visible=False)
|
|
|
144 |
)
|
145 |
search_button.click(
|
146 |
search_and_display,
|
147 |
+
inputs=[seq_input, n_results, index_selector],
|
148 |
outputs=[search_results, results_selector],
|
149 |
)
|
150 |
results_selector.change(
|
get_index.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import sys
|
|
|
2 |
from pathlib import Path
|
3 |
|
4 |
from huggingface_hub import snapshot_download
|
@@ -6,10 +8,11 @@ from huggingface_hub import snapshot_download
|
|
6 |
from credentials import get_token
|
7 |
|
8 |
|
9 |
-
def
|
10 |
index_path = Path(
|
11 |
snapshot_download(index_repo, use_auth_token=get_token(), repo_type="dataset")
|
12 |
)
|
|
|
13 |
local_arch_path = Path(
|
14 |
snapshot_download(model_repo, use_auth_token=get_token(), repo_type="model")
|
15 |
)
|
@@ -18,4 +21,9 @@ def get_engine(index_repo: str, model_repo: str):
|
|
18 |
ProteinSearchEngine,
|
19 |
)
|
20 |
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
import sys
|
3 |
+
from glob import glob
|
4 |
from pathlib import Path
|
5 |
|
6 |
from huggingface_hub import snapshot_download
|
|
|
8 |
from credentials import get_token
|
9 |
|
10 |
|
11 |
+
def get_engines(index_repo: str, model_repo: str):
|
12 |
index_path = Path(
|
13 |
snapshot_download(index_repo, use_auth_token=get_token(), repo_type="dataset")
|
14 |
)
|
15 |
+
|
16 |
local_arch_path = Path(
|
17 |
snapshot_download(model_repo, use_auth_token=get_token(), repo_type="model")
|
18 |
)
|
|
|
21 |
ProteinSearchEngine,
|
22 |
)
|
23 |
|
24 |
+
subindex_paths = glob(str(index_path / "*/"))
|
25 |
+
engines = {}
|
26 |
+
for subindex_path in subindex_paths:
|
27 |
+
subindex_name = os.path.basename(subindex_path).title()
|
28 |
+
engines[subindex_name] = ProteinSearchEngine(data_path=Path(subindex_path))
|
29 |
+
return engines
|