File size: 5,827 Bytes
f8080fc
 
19dfa7a
 
 
 
 
 
f8080fc
 
19dfa7a
f8080fc
 
 
 
 
 
 
 
 
 
83811e8
f8080fc
19dfa7a
f8080fc
19dfa7a
f8080fc
 
 
 
 
 
 
 
 
 
 
19dfa7a
f8080fc
 
 
 
 
 
 
 
 
19dfa7a
 
 
 
 
 
 
 
f8080fc
19dfa7a
 
f8080fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
 
f8080fc
 
19dfa7a
 
 
 
 
 
f8080fc
 
 
19dfa7a
f8080fc
19dfa7a
 
 
 
f8080fc
 
19dfa7a
f8080fc
 
19dfa7a
 
 
 
f8080fc
19dfa7a
 
f8080fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
 
f8080fc
 
 
 
 
f98cc68
f8080fc
 
 
19dfa7a
f8080fc
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import torch
from mammal.keys import (
    CLS_PRED,
    ENCODER_INPUTS_ATTENTION_MASK,
    ENCODER_INPUTS_STR,
    ENCODER_INPUTS_TOKENS,
)
from mammal.model import Mammal

from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class PpiTask(MammalTask):
    """Gradio demo task that asks a Mammal model whether two proteins interact.

    The task builds a binding-affinity-class prompt from two amino-acid
    sequences, runs generation, and reports the decoded output together with
    the score read at the positive ("<1>") token.
    """

    def __init__(self, model_dict):
        """Set up the task name, description, example sequences and intro text.

        Args:
            model_dict: mapping indexed by model name; each value is used as a
                model holder (see ``create_and_run_prompt``).
        """
        super().__init__(name="Protein-Protein Interaction", model_dict=model_dict)
        self.description = "Protein-Protein Interaction (PPI)"
        # Default sequences pre-filled into the two input text boxes.
        self.examples = {
            "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
            "protein_calcineurin": "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ",
        }
        # Markdown shown at the top of the demo panel.
        self.markup_text = f"""
    # Mammal based {self.description} demonstration

    Given two protein sequences, estimate if the proteins interact or not."""

    @staticmethod
    def positive_token_id(model_holder: MammalObjectBroker):
        """Return the id of the token that marks a positive (interacting) pair.

        Args:
            model_holder (MammalObjectBroker): object exposing ``tokenizer_op``

        Returns:
            the id ``tokenizer_op`` assigns to the ``"<1>"`` token
        """
        return model_holder.tokenizer_op.get_token_id("<1>")

    def generate_prompt(self, prot1: str, prot2: str) -> str:
        """Format the two sequences into a prompt matching pre-training syntax.

        Args:
            prot1 (str): sequence of protein number 1
            prot2 (str): sequence of protein number 2

        Returns:
            str: prompt
        """
        prompt = (
            "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
            + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            + f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"
            + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            + f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
        )
        return prompt

    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
        """Build the tokenized sample for one pair of proteins.

        The (misspelled) method name is kept unchanged because
        ``create_and_run_prompt`` and possibly other callers refer to it.

        Args:
            sample_inputs (dict): keyword arguments for ``generate_prompt``,
                i.e. ``{"prot1": <sequence>, "prot2": <sequence>}``
            model_holder (MammalObjectBroker): object exposing ``tokenizer_op``

        Returns:
            the sample returned by the tokenizer, holding the prompt string and
            the token ids / attention mask converted to ``torch`` tensors
        """
        # Unpack by keyword: a single ``*`` would pass the dict *keys*
        # ("prot1", "prot2") to the prompt instead of the sequences.
        prompt = self.generate_prompt(**sample_inputs)
        sample_dict = {ENCODER_INPUTS_STR: prompt}

        # Tokenize
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        # Convert the tokenizer outputs to tensors before they reach the model.
        sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
            sample_dict[ENCODER_INPUTS_TOKENS]
        )
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
            sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
        )
        return sample_dict

    def run_model(self, sample_dict, model: Mammal):
        """Run generation on a single sample and return the model's batch output.

        Args:
            sample_dict: tokenized sample as produced by ``crate_sample_dict``
            model (Mammal): model used for generation

        Returns:
            whatever ``model.generate`` returns for a one-element batch,
            with scores requested
        """
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, model_holder: MammalObjectBroker):
        """Extract the decoded text and the positive-class score.

        Args:
            batch_dict: output of ``run_model``
            model_holder (MammalObjectBroker): object exposing ``tokenizer_op``

        Returns:
            tuple: ``(generated_output, score)``
        """
        # Decode the first (only) entry of the batch.
        generated_output = model_holder.tokenizer_op._tokenizer.decode(
            batch_dict[CLS_PRED][0]
        )
        # NOTE(review): presumably [0] selects the sample and [1] the
        # generation step holding the class token - confirm against the model.
        score = batch_dict["model.out.scores"][0][1][
            self.positive_token_id(model_holder)
        ].item()

        return generated_output, score

    def create_and_run_prompt(self, model_name, protein1, protein2):
        """Gradio callback: build the prompt, run the model, decode the result.

        Args:
            model_name: key into ``self.model_dict`` selecting the model holder
            protein1: sequence of the first protein
            protein2: sequence of the second protein

        Returns:
            tuple: ``(prompt, generated_output, score)``
        """
        model_holder = self.model_dict[model_name]
        sample_inputs = {"prot1": protein1, "prot2": protein2}
        sample_dict = self.crate_sample_dict(
            sample_inputs=sample_inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
        return res

    def create_demo(self, model_name_widget: gr.component):
        """Build the (initially hidden) Gradio panel for this task.

        Args:
            model_name_widget: Gradio component whose value is passed as the
                model name to ``create_and_run_prompt``

        Returns:
            the ``gr.Group`` holding the task's widgets
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                prot1 = gr.Textbox(
                    label="Protein 1 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calmodulin"],
                )
                prot2 = gr.Textbox(
                    label="Protein 2 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calcineurin"],
                )
            with gr.Row():
                run_mammal: gr.Button = gr.Button(
                    "Run Mammal prompt for Protein-Protein Interaction",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                score_box = gr.Number(label="PPI score")
                # Wire the button once all output widgets exist.
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_widget, prot1, prot2],
                    outputs=[prompt_box, decoded, score_box],
                )
            with gr.Row():
                gr.Markdown(
                    "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
                )
            # The panel starts hidden.
            demo.visible = False
            return demo