File size: 6,482 Bytes
b3237f2
 
 
 
 
 
 
 
 
 
91cc187
b3237f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e372ce4
 
5f761b4
 
e372ce4
b3237f2
 
 
5f761b4
b3237f2
 
3ded8ec
b3237f2
bc8e1db
 
a03b8db
4c1817c
 
 
b3237f2
 
4c1817c
 
b3237f2
 
4e601e7
 
e372ce4
 
 
 
 
 
 
 
 
 
9a02516
 
e372ce4
 
 
 
 
 
3ded8ec
 
 
e372ce4
 
b3237f2
4e601e7
e372ce4
 
 
 
 
 
b3237f2
 
 
4e601e7
 
 
 
 
 
b3237f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e372ce4
b3237f2
 
 
 
 
6c96c76
 
 
 
b3237f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
import joblib
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import random
import tensorflow as tf
import os
from keras.layers import TFSMLayer
import pandas as pd

# Directory the app treats as its root. NOTE(review): not referenced by the
# code visible in this file -- confirm before relying on it.
base_dir = "."

# One fixed seed shared by every random number generator the app touches
# (Python's `random`, NumPy, and PyTorch on CPU and, if present, CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic algorithms and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_model(model_path):
    """Load the TensorFlow SavedModel stored at ``model_path``.

    A progress line naming the path is printed first; the object produced by
    ``tf.saved_model.load`` is returned unchanged.
    """
    print(f"Loading model from {model_path}...")
    saved_model = tf.saved_model.load(model_path)
    return saved_model


print("Loading models...")

# Each registry maps a target name to the predictor for that target plus the
# ESM checkpoint and hidden-state index whose embedding the predictor consumes.

# Plant-specific predictors, stored as joblib pickles in the working directory.
# NOTE: joblib.load unpickles its input, so these files must be trusted.
plant_models = {
    "Specificity": {
        "model": joblib.load("Specificity.pkl"),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 6,
    },
    "kcatC": {
        "model": joblib.load("kcatC.pkl"),
        "esm_model": "facebook/esm2_t36_3B_UR50D",
        "layer": 11,
    },
    "KC": {
        "model": joblib.load("KC.pkl"),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 4,
    },
}

# General predictors, stored as TensorFlow SavedModel directories.
general_models = {
    "Specificity": {
        "model": load_model("Specificity"),
        "esm_model": "facebook/esm2_t33_650M_UR50D",
        "layer": 33,
    },
    "kcatC": {
        "model": load_model("kcatC"),
        "esm_model": "facebook/esm2_t12_35M_UR50D",
        "layer": 7,
    },
    "KC": {
        "model": load_model("KC"),
        "esm_model": "facebook/esm2_t30_150M_UR50D",
        "layer": 26,
    },
}


# Function to generate embeddings
def get_embedding(sequence, esm_model_name, layer):
    """Embed ``sequence`` using one hidden layer of a Hugging Face ESM model.

    The hidden state at index ``layer`` is averaged over dim 1 and the first
    row of the result is laid out as a one-row DataFrame whose columns are
    named ``D1`` .. ``Dn``.

    Args:
        sequence: Protein sequence text handed to the tokenizer.
        esm_model_name: Hugging Face checkpoint name for tokenizer and model.
        layer: Index into the model's ``hidden_states`` tuple.

    Returns:
        ``(values, frame)``: the DataFrame's underlying array and the
        DataFrame itself.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")

    # NOTE(review): tokenizer and model are re-instantiated on every call.
    esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    esm = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)

    # The tokenizer truncates anything beyond 1024 tokens.
    encoded = esm_tokenizer(
        sequence,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        layer_states = esm(**encoded).hidden_states[layer]
        pooled = layer_states.mean(dim=1).numpy()  # average pooling

    # One named column per embedding dimension, numbered from 1.
    named_features = {
        f"D{position}": value
        for position, value in enumerate(pooled[0], start=1)
    }
    embedding_df = pd.DataFrame([named_features])
    print(embedding_df)
    return embedding_df.values, embedding_df


def predict_with_gpflow(model, X):
    """Run the ``serving_default`` signature of a loaded model on ``X``.

    ``X`` is converted to a float64 tensor and passed as the ``Xnew`` keyword
    argument. The signature's ``output_0`` and ``output_1`` entries are used
    as mean and variance respectively.

    Returns:
        ``(mean, variance)`` as flattened NumPy arrays.
    """
    print(model.signatures)

    inputs = tf.convert_to_tensor(X, dtype=tf.float64)
    print(inputs.shape)

    serving_fn = model.signatures["serving_default"]
    outputs = serving_fn(Xnew=inputs)

    # NOTE(review): the output key names are fixed here; adjust them if the
    # exported model names its outputs differently.
    mean = outputs["output_0"].numpy()
    variance = outputs["output_1"].numpy()
    return mean.flatten(), variance.flatten()



def process_target(target, selected_models, sequence, prediction_type):
    """
    Predict one target for ``sequence`` from its ESM embedding.

    Args:
        target: Key into ``selected_models``.
        selected_models: Registry mapping each target to a dict with the keys
            "esm_model", "layer" and "model".
        sequence: Protein sequence to embed.
        prediction_type: "Plant-Specific" calls ``model.predict``; any other
            value goes through ``predict_with_gpflow``.

    Returns:
        ``(target, prediction)`` for "Plant-Specific", otherwise
        ``(target, prediction, uncertainty)``; numbers are rounded to 2 places.
    """
    entry = selected_models[target]
    esm_model_name = entry["esm_model"]
    layer = entry["layer"]
    model = entry["model"]

    # Only the raw array is needed here; it is cast to float64 before use.
    features = get_embedding(sequence, esm_model_name, layer)[0].astype(np.float64)

    # NOTE(review): every call overwrites this file in the working directory.
    np.save(f"hf_embedding_{target}.npy", features)

    if prediction_type != "Plant-Specific":
        # GPflow path: returns a prediction and its uncertainty.
        print(esm_model_name)
        print(layer)
        print(model)
        means, variances = predict_with_gpflow(model, features)
        return target, round(means[0], 2), round(variances[0], 2)

    # Random Forest path: a single point prediction.
    point_estimate = model.predict(features)[0]
    return target, round(point_estimate, 2)


def predict(sequence, prediction_type):
    """
    Predicts Specificity, kcatC, and KC for the given sequence and prediction type.

    Args:
        sequence: Protein sequence text from the UI. All whitespace (including
            line breaks from a multi-line paste) is removed before prediction.
        prediction_type: "Plant-Specific" selects ``plant_models``; any other
            value selects ``general_models``.

    Returns:
        A list of rows in the order Specificity, kcat, K. Each row is
        ``[label, prediction]`` for "Plant-Specific" and
        ``[label, prediction, uncertainty]`` otherwise.

    Raises:
        gr.Error: If the sequence is empty or contains only whitespace, so the
            UI shows a message instead of a prediction made from no residues.
    """
    # Reject empty input up front; otherwise the models would be run on an
    # embedding that contains no residues at all.
    cleaned_sequence = "".join((sequence or "").split())
    if not cleaned_sequence:
        raise gr.Error("Please enter a protein sequence.")

    # Select the appropriate model set
    selected_models = plant_models if prediction_type == "Plant-Specific" else general_models

    # Predict for all targets in parallel; map() yields results in input order.
    with ThreadPoolExecutor() as executor:
        results = list(
            executor.map(
                lambda target: process_target(
                    target, selected_models, cleaned_sequence, prediction_type
                ),
                selected_models,
            )
        )

    # Index results by target name so row construction does not depend on the
    # key order of the model registries. Each value is (prediction,) or
    # (prediction, uncertainty), so one code path serves both modes.
    values_by_target = {result[0]: result[1:] for result in results}
    display_labels = (
        ("Specificity", "Specificity"),
        ("kcatC", "kcat\u1d9c"),
        ("KC", "K\u1d9c"),
    )
    return [[label, *values_by_target[target]] for target, label in display_labels]


# Define the Gradio interface: a sequence text box and a prediction-type
# selector feed `predict`, whose rows are rendered as a table.
print("Creating Gradio interface...")
interface = gr.Interface(
    fn=predict,
    inputs=[
        # Multi-line text box, pre-filled with an example sequence.
        gr.Textbox(label="Input Protein Sequence",
                   value="MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYHIEPVPGEETQFIAYVAYPLDLFEEGSVTNMFTSIVGNVFGFKALAALRLEDLRIPPAYTKTFQGPPHGIQVERDKLNKYGRPLLGCTIKPKLGLSAKNYGRAVYECLRGGLDFTKDDENVNSQPFMRWRDRFLFCAEAIYKSQAETGEIKGHYLNATAGTCEEMIKRAVFARELGVPIVMHDYLTGGFTANTSLSHYCRDNGLLLHIHRAMHAVIDRQKNHGMHFRVLAKALRLSGGDHIHAGTVVGKLEGDRESTLGFVDLLRDDYVEKDRSRGIFFTQDWVSLPGVLPVASGGIHVWHMPALTEIFGDDSVLQFGGGTLGHPWGNAPGAVANRVALEACVQARNEGRDLAVEGNEIIREACKWSPELAAACEVWKEITFNFPTIDKLDGQE",
            lines=10, 
                  ),  # Input: text box for the protein sequence
        gr.Radio(choices=["Plant-Specific", "General"], label="Prediction Type", value="Plant-Specific"),  # Radio buttons selecting which model set `predict` uses
    ],
    # Three headers are declared, but `predict` returns only two values per row
    # for "Plant-Specific"; the third column is populated only for "General".
    outputs=gr.Dataframe(
        headers=["Target", "Prediction", "Uncertainty (for General)"], 
        type="array"
    ),  # Output: table with one row per target
    title="Rubisco Kinetics Prediction",
    description=(
        "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
        "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
    ),
)

# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()