roni commited on
Commit
6509a73
1 Parent(s): 8a1f601

supporting multiple indexes

Browse files
Files changed (2) hide show
  1. app.py +61 -20
  2. get_index.py +10 -2
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import gradio as gr
2
 
3
  from concurrency import execute_multithread
4
- from get_index import get_engine
5
  from protein_viz import get_gene_name, get_protein_name, render_html
6
 
7
  index_repo = "ronig/siamese_protein_index"
8
  model_repo = "ronig/protein_search_engine"
9
- engine = get_engine(index_repo, model_repo)
 
10
 
11
 
12
- def search_and_display(seq, n_res):
13
  n_res = int(limit_n_results(n_res))
 
14
  search_res = engine.search_by_sequence(seq, n=n_res)
15
  results_options = update_dropdown_menu(search_res)
16
  formatted_search_results = format_search_results(search_res)
@@ -24,12 +26,18 @@ def limit_n_results(n):
24
  def update_dropdown_menu(search_res):
25
  choices = []
26
  for row in search_res:
27
- choice = ".".join([row["pdb_name"], row["chain_id"]])
28
- choices.append(choice)
29
-
30
- return gr.Dropdown.update(
31
- choices=choices, interactive=True, value=choices[0], visible=True
32
- )
 
 
 
 
 
 
33
 
34
 
35
  def format_search_results(raw_search_results):
@@ -44,6 +52,22 @@ def format_search_results(raw_search_results):
44
 
45
 
46
  def format_search_result(raw_result):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  prot = raw_result["pdb_name"]
48
  chain = raw_result["chain_id"]
49
  value = raw_result["score"]
@@ -54,16 +78,28 @@ def format_search_result(raw_result):
54
  return key, value
55
 
56
 
 
 
 
 
 
 
57
  def switch_viz(new_choice):
58
- choice_parts = new_choice.split(".")
59
- pdb_id, chain = choice_parts[0], choice_parts[1]
60
- title_update = gr.Markdown.update(visible=True)
61
- protein_name = get_protein_name(pdb_id)
 
 
 
 
 
62
 
63
- new_value = f"""**PDB Title**: {protein_name}"""
64
 
65
- description_update = gr.Markdown.update(value=new_value, visible=True)
66
- return render_html(pdb_id=pdb_id, chain=chain), title_update, description_update
 
67
 
68
 
69
  with gr.Blocks() as demo:
@@ -79,10 +115,15 @@ with gr.Blocks() as demo:
79
  with gr.Column():
80
  with gr.Row():
81
  with gr.Column():
82
- seq_input = gr.Textbox(
83
- value="APTMPPPLPP", label="Input Sequence"
84
- )
85
  n_results = gr.Number(5, label="N Results")
 
 
 
 
 
 
 
86
  search_button = gr.Button("Search", variant="primary")
87
  search_results = gr.Label(num_top_classes=20, label="Search Results")
88
  viz_header = gr.Markdown("## Visualization", visible=False)
@@ -103,7 +144,7 @@ with gr.Blocks() as demo:
103
  )
104
  search_button.click(
105
  search_and_display,
106
- inputs=[seq_input, n_results],
107
  outputs=[search_results, results_selector],
108
  )
109
  results_selector.change(
 
1
  import gradio as gr
2
 
3
  from concurrency import execute_multithread
4
+ from get_index import get_engines
5
  from protein_viz import get_gene_name, get_protein_name, render_html
6
 
7
  index_repo = "ronig/siamese_protein_index"
8
  model_repo = "ronig/protein_search_engine"
9
+ engines = get_engines(index_repo, model_repo)
10
+ available_indexes = list(engines.keys())
11
 
12
 
13
+ def search_and_display(seq, n_res, index_selection):
14
  n_res = int(limit_n_results(n_res))
15
+ engine = engines[index_selection]
16
  search_res = engine.search_by_sequence(seq, n=n_res)
17
  results_options = update_dropdown_menu(search_res)
18
  formatted_search_results = format_search_results(search_res)
 
26
  def update_dropdown_menu(search_res):
27
  choices = []
28
  for row in search_res:
29
+ if "pdb_name" in row and "chain_id" in row:
30
+ choice = ".".join([row["pdb_name"], row["chain_id"]])
31
+ choices.append(choice)
32
+ if choices:
33
+ update = gr.Dropdown.update(
34
+ choices=choices, interactive=True, value=choices[0], visible=True
35
+ )
36
+ else:
37
+ update = gr.Dropdown.update(
38
+ choices=choices, interactive=True, visible=False, value=None
39
+ )
40
+ return update
41
 
42
 
43
  def format_search_results(raw_search_results):
 
52
 
53
 
54
  def format_search_result(raw_result):
55
+ is_pdb = "pdb_name" in raw_result
56
+ if is_pdb:
57
+ key, value = parse_pdb_search_result(raw_result)
58
+ else:
59
+ key, value = parse_fasta_search_result(raw_result)
60
+ return key, value
61
+
62
+
63
+ def parse_fasta_search_result(raw_result):
64
+ gene = parse_gene_from_fasta_entry(raw_result["description"])
65
+ key = f"Gene: {gene}"
66
+ value = raw_result["score"]
67
+ return key, value
68
+
69
+
70
+ def parse_pdb_search_result(raw_result):
71
  prot = raw_result["pdb_name"]
72
  chain = raw_result["chain_id"]
73
  value = raw_result["score"]
 
78
  return key, value
79
 
80
 
81
+ def parse_gene_from_fasta_entry(description):
82
+ after = description.split("GN=")[1]
83
+ gene = after.split(" ")[0]
84
+ return gene
85
+
86
+
87
  def switch_viz(new_choice):
88
+ if new_choice is None:
89
+ html = ""
90
+ title_update = gr.Markdown.update(visible=False)
91
+ description_update = gr.Markdown.update(value=None, visible=False)
92
+ else:
93
+ choice_parts = new_choice.split(".")
94
+ pdb_id, chain = choice_parts[0], choice_parts[1]
95
+ title_update = gr.Markdown.update(visible=True)
96
+ protein_name = get_protein_name(pdb_id)
97
 
98
+ new_value = f"""**PDB Title**: {protein_name}"""
99
 
100
+ description_update = gr.Markdown.update(value=new_value, visible=True)
101
+ html = render_html(pdb_id=pdb_id, chain=chain)
102
+ return html, title_update, description_update
103
 
104
 
105
  with gr.Blocks() as demo:
 
115
  with gr.Column():
116
  with gr.Row():
117
  with gr.Column():
118
+ seq_input = gr.Textbox(value="APTMPPPLPP", label="Input Sequence")
 
 
119
  n_results = gr.Number(5, label="N Results")
120
+ index_selector = gr.Dropdown(
121
+ choices=available_indexes,
122
+ value=available_indexes[0],
123
+ multiselect=False,
124
+ visible=True,
125
+ label="Index",
126
+ )
127
  search_button = gr.Button("Search", variant="primary")
128
  search_results = gr.Label(num_top_classes=20, label="Search Results")
129
  viz_header = gr.Markdown("## Visualization", visible=False)
 
144
  )
145
  search_button.click(
146
  search_and_display,
147
+ inputs=[seq_input, n_results, index_selector],
148
  outputs=[search_results, results_selector],
149
  )
150
  results_selector.change(
get_index.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import sys
 
2
  from pathlib import Path
3
 
4
  from huggingface_hub import snapshot_download
@@ -6,10 +8,11 @@ from huggingface_hub import snapshot_download
6
  from credentials import get_token
7
 
8
 
9
- def get_engine(index_repo: str, model_repo: str):
10
  index_path = Path(
11
  snapshot_download(index_repo, use_auth_token=get_token(), repo_type="dataset")
12
  )
 
13
  local_arch_path = Path(
14
  snapshot_download(model_repo, use_auth_token=get_token(), repo_type="model")
15
  )
@@ -18,4 +21,9 @@ def get_engine(index_repo: str, model_repo: str):
18
  ProteinSearchEngine,
19
  )
20
 
21
- return ProteinSearchEngine(data_path=index_path)
 
 
 
 
 
 
1
+ import os.path
2
  import sys
3
+ from glob import glob
4
  from pathlib import Path
5
 
6
  from huggingface_hub import snapshot_download
 
8
  from credentials import get_token
9
 
10
 
11
+ def get_engines(index_repo: str, model_repo: str):
12
  index_path = Path(
13
  snapshot_download(index_repo, use_auth_token=get_token(), repo_type="dataset")
14
  )
15
+
16
  local_arch_path = Path(
17
  snapshot_download(model_repo, use_auth_token=get_token(), repo_type="model")
18
  )
 
21
  ProteinSearchEngine,
22
  )
23
 
24
+ subindex_paths = glob(str(index_path / "*/"))
25
+ engines = {}
26
+ for subindex_path in subindex_paths:
27
+ subindex_name = os.path.basename(subindex_path).title()
28
+ engines[subindex_name] = ProteinSearchEngine(data_path=Path(subindex_path))
29
+ return engines