File size: 3,883 Bytes
ac117b5
 
 
 
 
 
 
 
 
81fb8a8
ac117b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec3465
ac117b5
 
 
cec3465
ac117b5
 
cec3465
ac117b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cc43e
ac117b5
 
 
cec3465
 
 
ac117b5
 
 
 
32cc43e
 
 
 
ac117b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cec3465
ac117b5
cec3465
ac117b5
 
 
 
cec3465
ac117b5
cec3465
ac117b5
32cc43e
8b16321
32cc43e
ac117b5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr

import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal
from mammal.keys import *



# Identifier handed to both `Mammal.from_pretrained` and
# `ModularTokenizerOp.from_pretrained` below, so the model and tokenizer
# always come from the same source.
model_path="ibm/biomed.omics.bl.sm.ma-ted-458m"
# Load the model once at import time; every request served by the UI reuses it.
model = Mammal.from_pretrained(model_path)
# NOTE(review): presumably switches the underlying torch module to inference
# mode (no dropout etc.) - confirm against the Mammal class.
model.eval()

# Load the tokenizer that matches the model above.
tokenizer_op = ModularTokenizerOp.from_pretrained(model_path)

# Vocabulary id of the "<1>" token. The UI text describes "<1>" as the
# "interacting" class, and run_prompt() uses this id to pick the score it reports.
positive_token_id=tokenizer_op.get_token_id("<1>")

# Default protein sequences pre-filled into the two input text boxes.
protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"


def format_prompt(prot1, prot2):
    """Build the Mammal prompt string for a pair of protein sequences.

    The prompt is a task header (tokenizer type, binding-affinity class,
    sentinel placeholder) followed by one tagged entity per protein and a
    closing ``<EOS>``. The layout follows the pre-training syntax.

    Args:
        prot1: Amino-acid sequence of the first protein.
        prot2: Amino-acid sequence of the second protein.

    Returns:
        The fully formatted prompt string.
    """

    def _protein_entity(sequence):
        # One protein wrapped in its entity / sequence start-end markers.
        return (
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{sequence}<SEQUENCE_NATURAL_END>"
        )

    task_header = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
    pieces = [
        task_header,
        _protein_entity(prot1),
        _protein_entity(prot2),
        "<EOS>",
    ]
    return "".join(pieces)

def run_prompt(prompt):
    """Tokenize a prompt, run generation on the model and read out the result.

    Args:
        prompt: Prompt string, e.g. as produced by ``format_prompt``.

    Returns:
        A ``(generated_output, score)`` tuple: the decoded text of the first
        prediction in the batch, and the scalar taken from the generation
        scores at ``positive_token_id``.
    """
    # Start the sample with just the raw prompt text.
    sample = {ENCODER_INPUTS_STR: prompt}

    # The tokenizer op writes token ids and the attention mask into the sample.
    sample = tokenizer_op(
        sample_dict=sample,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
    )

    # Both tokenizer outputs are converted to torch tensors before generation.
    for tensor_key in (ENCODER_INPUTS_TOKENS, ENCODER_INPUTS_ATTENTION_MASK):
        sample[tensor_key] = torch.tensor(sample[tensor_key])

    # Generate for a batch of one sample, asking for per-step scores too.
    generation = model.generate(
        [sample],
        output_scores=True,
        return_dict_in_generate=True,
        max_new_tokens=5,
    )

    # Decode the first (and only) prediction of the batch.
    decoded_text = tokenizer_op._tokenizer.decode(generation[CLS_PRED][0])

    # NOTE(review): index [0][1] presumably selects batch item 0 and the
    # second generated position (where the class token is expected) - confirm
    # against the Mammal generate() output layout.
    positive_score = generation['model.out.scores'][0][1][positive_token_id].item()

    return decoded_text, positive_score

def create_and_run_prompt(prot1, prot2):
    """Format the prompt for two proteins, run it, and bundle the results.

    Args:
        prot1: Sequence of the first protein.
        prot2: Sequence of the second protein.

    Returns:
        A tuple that starts with the prompt string and continues with
        everything ``run_prompt`` returned for it.
    """
    mammal_prompt = format_prompt(prot1, prot2)
    model_outputs = run_prompt(prompt=mammal_prompt)
    # Prompt first, so it lines up with the first output component in the UI.
    return (mammal_prompt, *model_outputs)

def create_application():
    """Assemble the Gradio Blocks UI for the PPI demonstration.

    The page has a Markdown header, two protein-sequence text boxes
    (pre-filled with the module's default sequences), a run button, and
    output fields for the prompt, the decoded model output and the score.
    Clicking the button calls ``create_and_run_prompt``.

    Returns:
        The constructed ``gr.Blocks`` object; the caller launches it.
    """
    # Built line by line so the significant trailing space after "from" and
    # the leading space before the code span stay visible.
    markup_text = "\n".join(
        [
            "",
            "# Mammal based Protein-Protein Interaction (PPI) demonstration",
            "",
            "Given two protein sequences, estimate if the proteins interact or not.",
            "",
            "### Using the model from ",
            "",
            f" ```{model_path} ```",
            "",
        ]
    )

    with gr.Blocks() as demo:
        gr.Markdown(markup_text)

        # Inputs: the two protein sequences, side by side.
        with gr.Row():
            sequence_box_options = dict(interactive=True, lines=1)
            prot1 = gr.Textbox(
                label="Protein 1 sequence",
                value=protein_calmodulin,
                **sequence_box_options,
            )
            prot2 = gr.Textbox(
                label="Protein 2 sequence",
                value=protein_calcineurin,
                **sequence_box_options,
            )

        # Trigger button.
        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Protein-Protein Interaction",
                variant='primary',
            )

        # Output: the exact prompt sent to the model.
        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

        # Output: decoded model text and the numeric score, in one row.
        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            score_box = gr.Number(label='PPI score')
            run_mammal.click(
                fn=create_and_run_prompt,
                inputs=[prot1, prot2],
                outputs=[prompt_box, decoded, score_box],
            )

        # Footer explaining how to read the model output.
        with gr.Row():
            gr.Markdown("```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting")

    return demo

def main():
    """Build the demo UI and launch it.

    The app is launched with ``show_error=True`` and ``share=True``.
    """
    application = create_application()
    launch_options = {"show_error": True, "share": True}
    application.launch(**launch_options)


# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()