File size: 2,464 Bytes
08b9eb6
 
 
 
 
 
 
 
 
dd9c8e6
0137aa6
 
 
 
 
285b88f
08b9eb6
 
 
 
 
 
 
 
 
 
dd9c8e6
 
 
 
0137aa6
 
dd9c8e6
61cedea
08b9eb6
61cedea
dd9c8e6
08b9eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
from esm_scripts.extract import run_demo
from esm import pretrained


# ESM-2 protein language model (the 'esm2_t36_3B_UR50D' checkpoint) plus its
# tokenizer alphabet, loaded once at import time. It is only used to compute
# embeddings, so it is put in eval mode (disables dropout) and moved to the GPU.
# nn.Module.eval() and .cuda() both return the module, so the calls chain.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm = model_esm.eval().cuda()

# FAPM captioning model sized for the 3B ESM embeddings, with its fine-tuned
# weights restored from the local checkpoint. nn.Module.to moves parameters in
# place, so the return value does not need to be rebound.
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
model.load_checkpoint("model/checkpoint_mf2.pth")
model.to('cuda')


def _clean_sequence(protein):
    """Normalise raw textbox input into a bare upper-case amino-acid string.

    Lines starting with ``>`` (FASTA headers) are dropped, all whitespace
    (including line breaks inside a wrapped sequence) is removed, and the
    remaining characters are upper-cased. ``None`` or empty input yields ``''``.
    """
    if not protein:
        return ''
    body_lines = (
        line for line in protein.splitlines()
        if not line.lstrip().startswith('>')
    )
    return ''.join(''.join(body_lines).split()).upper()


@spaces.GPU
def generate_caption(protein, prompt):
    """Predict textual function descriptions for a single protein sequence.

    The sequence is embedded with the ESM-2 model (``model_esm``, per-token
    representations from layer 36). The embedding is zero-padded along dim 0
    to ``max_len`` rows and passed to the FAPM model together with ``prompt``.

    Args:
        protein: Amino-acid sequence typed or pasted by the user. Surrounding
            or internal whitespace, lowercase letters and FASTA header lines
            are tolerated and stripped before embedding.
        prompt: Free-text prompt forwarded to the model (e.g. taxonomy info).

    Returns:
        The value of ``model.generate`` for one sample with ``num_captions=10``.
        Presumably a list of candidate descriptions; TODO confirm against
        ``Blip2ProteinMistral.generate``.

    Raises:
        gr.Error: If no sequence characters remain after cleaning, so the user
            sees a readable message instead of a stack trace.
    """
    # Single source of truth for both the ESM truncation length and the padded
    # length the FAPM model receives.
    max_len = 1024

    # Whitespace, lowercase residues and '>' are not ESM alphabet tokens;
    # presumably they would otherwise fail or be mis-tokenized.
    sequence = _clean_sequence(protein)
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    esm_emb = run_demo(protein_name='protein_name', protein_seq=sequence,
                       model=model_esm, alphabet=alphabet,
                       include='per_tok', repr_layers=36,
                       truncation_seq_length=max_len)
    print("esm embedding generated")
    # Zero-pad dim 0 (presumably the residue axis) up to max_len rows. Slicing
    # first keeps the pad amount non-negative even if more rows come back.
    esm_emb = esm_emb[:max_len]
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')
    print("esm embedding processed")

    # Batch of one sample in the dict layout the FAPM model consumes.
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}
    # Beam search: 15 beams, keep the top 10 candidate descriptions.
    prediction = model.generate(samples, length_penalty=0., num_beams=15,
                                num_captions=10, temperature=1.,
                                repetition_penalty=1.0)

    return prediction

# Define the FAPM interface: two text inputs (sequence and prompt) feeding
# generate_caption, with the model output shown in a single textbox.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.

The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

iface = gr.Interface(
    fn=generate_caption,
    inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
    outputs=gr.Textbox(label="Generated description"),
    description=description
)

# Launch only when executed as a script (how the Space runs app.py). Importing
# the module, e.g. from tooling or Gradio's reload mode, does not start a
# second server.
if __name__ == "__main__":
    iface.launch()