wenkai committed on
Commit
aad9fe1
1 Parent(s): ea37187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -25
app.py CHANGED
@@ -7,17 +7,39 @@ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
 
10
  from esm import pretrained, FastaBatchedDataset
11
  from data.evaluate_data.utils import Ontology
12
  import difflib
13
  import re
14
-
15
-
16
- # Load the model
17
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
18
- model.load_checkpoint("model/checkpoint_mf2.pth")
19
- model.to('cuda')
20
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
22
  model_esm.to('cuda')
23
  model_esm.eval()
@@ -39,7 +61,7 @@ choices = {x.lower(): x for x in choices_mf}
39
 
40
 
41
  @spaces.GPU
42
- def generate_caption(protein, prompt):
43
  # Process the image and the prompt
44
  # with open('/home/user/app/example.fasta', 'w') as f:
45
  # f.write('>{}\n'.format("protein_name"))
@@ -122,8 +144,9 @@ def generate_caption(protein, prompt):
122
  'text_input': ['none'],
123
  'prompt': [prompt]}
124
 
 
125
  # Generate the output
126
- prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
127
  repetition_penalty=1.0)
128
 
129
  x = prediction[0]
@@ -140,12 +163,17 @@ def generate_caption(protein, prompt):
140
  pred_terms.append(t_standard+f'({prob})')
141
  temp.append(t_standard)
142
  if prompt == 'none':
143
- res_str = "No available predictions for this protein, you can try to remove prompt!"
144
  else:
145
- res_str = "No available predictions for this protein, you can try another protein sequence!"
146
  if len(pred_terms) == 0:
147
  return res_str
148
- res_str = f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}"
 
 
 
 
 
149
  return res_str
150
  # return "test"
151
 
@@ -155,7 +183,6 @@ description = """Quick demonstration of the FAPM model for protein function pred
155
 
156
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
157
 
158
-
159
  # iface = gr.Interface(
160
  # fn=generate_caption,
161
  # inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
@@ -165,7 +192,6 @@ The model used in this app is available at [Hugging Face Model Hub](https://hugg
165
  # # Launch the interface
166
  # iface.launch()
167
 
168
-
169
  css = """
170
  #output {
171
  height: 500px;
@@ -179,30 +205,29 @@ with gr.Blocks(css=css) as demo:
179
  with gr.Tab(label="Protein caption"):
180
  with gr.Row():
181
  with gr.Column():
 
182
  input_protein = gr.Textbox(type="text", label="Upload sequence")
183
- # model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='microsoft/Florence-2-large')
184
  prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
185
  submit_btn = gr.Button(value="Submit")
186
  with gr.Column():
187
  output_text = gr.Textbox(label="Output Text")
188
- # train index 127, 266, 738, 1060 test index 4
189
  gr.Examples(
190
  examples=[
191
- ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
192
- ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
193
- ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
194
- ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
195
- ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
196
- ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
197
  ],
198
- inputs=[input_protein, prompt],
199
  outputs=[output_text],
200
  fn=generate_caption,
201
  cache_examples=True,
202
  label='Try examples'
203
  )
204
-
205
- submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
206
 
207
  demo.launch(debug=True)
208
 
 
7
  from lavis.models.base_model import FAPMConfig
8
  import spaces
9
  import gradio as gr
10
+ # from esm_scripts.extract import run_demo
11
  from esm import pretrained, FastaBatchedDataset
12
  from data.evaluate_data.utils import Ontology
13
  import difflib
14
  import re
15
+ from transformers import MistralForCausalLM
16
+
17
+ # Load the trained model
18
+ def get_model(type='Molecule Function'):
19
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
20
+ if type == 'Molecule Function':
21
+ model.load_checkpoint("model/checkpoint_mf2.pth")
22
+ model.to('cuda')
23
+ elif type == 'Biological Process':
24
+ model.load_checkpoint("model/checkpoint_bp1.pth")
25
+ model.to('cuda')
26
+ elif type == 'Cellar Component':
27
+ model.load_checkpoint("model/checkpoint_cc2.pth")
28
+ model.to('cuda')
29
+ return model
30
+
31
+
32
+ models = {
33
+ 'Molecule Function': get_model('Molecule Function'),
34
+ 'Biological Process': get_model('Biological Process'),
35
+ 'Cellar Component': get_model('Cellar Component'),
36
+ }
37
+
38
+
39
+ # Load the mistral model
40
+ mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
41
+
42
+ # Load ESM2 model
43
  model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
44
  model_esm.to('cuda')
45
  model_esm.eval()
 
61
 
62
 
63
  @spaces.GPU
64
+ def generate_caption(model_id, protein, prompt):
65
  # Process the image and the prompt
66
  # with open('/home/user/app/example.fasta', 'w') as f:
67
  # f.write('>{}\n'.format("protein_name"))
 
144
  'text_input': ['none'],
145
  'prompt': [prompt]}
146
 
147
+ model = models[model_id]
148
  # Generate the output
149
+ prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
150
  repetition_penalty=1.0)
151
 
152
  x = prediction[0]
 
163
  pred_terms.append(t_standard+f'({prob})')
164
  temp.append(t_standard)
165
  if prompt == 'none':
166
+ res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
167
  else:
168
+ res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
169
  if len(pred_terms) == 0:
170
  return res_str
171
+ if model_id == 'Molecule Function':
172
+ res_str = f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}"
173
+ elif model_id == 'Biological Process':
174
+ res_str = f"Based on the given amino acid sequence, it is likely involved in the {', '.join(pred_terms)}"
175
+ elif model_id == 'Cellar Component':
176
+ res_str = f"Based on the given amino acid sequence, it's subcellular localization is within the {', '.join(pred_terms)}"
177
  return res_str
178
  # return "test"
179
 
 
183
 
184
  The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
185
 
 
186
  # iface = gr.Interface(
187
  # fn=generate_caption,
188
  # inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
 
192
  # # Launch the interface
193
  # iface.launch()
194
 
 
195
  css = """
196
  #output {
197
  height: 500px;
 
205
  with gr.Tab(label="Protein caption"):
206
  with gr.Row():
207
  with gr.Column():
208
+ model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='Molecule Function')
209
  input_protein = gr.Textbox(type="text", label="Upload sequence")
 
210
  prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
211
  submit_btn = gr.Button(value="Submit")
212
  with gr.Column():
213
  output_text = gr.Textbox(label="Output Text")
214
+ # O14813 train index 127, 266, 738, 1060 test index 4
215
  gr.Examples(
216
  examples=[
217
+ ["Molecule Function", "MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
218
+ ["Molecule Function", "MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
219
+ ["Molecule Function", "MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
220
+ ["Molecule Function", 'MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
221
+ ["Molecule Function", 'MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
222
+ ["Molecule Function", 'MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
223
  ],
224
+ inputs=[model_selector, input_protein, prompt],
225
  outputs=[output_text],
226
  fn=generate_caption,
227
  cache_examples=True,
228
  label='Try examples'
229
  )
230
+ submit_btn.click(generate_caption, [model_selector, input_protein, prompt], [output_text])
 
231
 
232
  demo.launch(debug=True)
233