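"""Gradio app for predicting Rubisco kinetics (Specificity, kcatC, and KC) from a protein sequence."""
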
import gradio as gr
import joblib
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import random
import tensorflow as tf
import os
from keras.layers import TFSMLayer
import pandas as pd

base_dir = "."
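
# Fix random seeds so Python, NumPy, and PyTorch give reproducible results across runs.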
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
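

# Load a TensorFlow SavedModel directory (used for the GPflow-based "General" models).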
def load_model(model_path):
    print(f"Loading model from {model_path}...")
    return tf.saved_model.load(model_path)


print("Loading models...")

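# Plant-specific models: Random Forest regressors loaded from joblib pickles, each paired with
# the ESM checkpoint and hidden-state layer used to build its input features.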
plant_models = {
    "Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
    "kcatC": {"model": joblib.load("kcatC.pkl"), "esm_model": "facebook/esm2_t36_3B_UR50D", "layer": 11},
    "KC": {"model": joblib.load("KC.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 4},
}

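# General models: GPflow models loaded from TensorFlow SavedModel directories and called via
# their serving signatures (see predict_with_gpflow below).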
general_models = {
    "Specificity": {"model": load_model("Specificity"), "esm_model": "facebook/esm2_t33_650M_UR50D", "layer": 33},
    "kcatC": {"model": load_model("kcatC"), "esm_model": "facebook/esm2_t12_35M_UR50D", "layer": 7},
    "KC": {"model": load_model("KC"), "esm_model": "facebook/esm2_t30_150M_UR50D", "layer": 26},
}
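

# Feature extraction: tokenize the sequence, run the ESM model, and mean-pool the hidden states
# at the requested layer into a fixed-length embedding (columns named D1..Dn).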
def get_embedding(sequence, esm_model_name, layer):
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)

    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)

    with torch.no_grad():
        outputs = model(**inputs)
        hidden_states = outputs.hidden_states
        embedding = hidden_states[layer].mean(dim=1).numpy()

    feature_columns = {f"D{i+1}": embedding[0, i] for i in range(embedding.shape[1])}
    embedding_df = pd.DataFrame([feature_columns])
    print(embedding_df)
    return embedding_df.values, embedding_df
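

# Query a GPflow SavedModel through its serving signature; "output_0" holds the predictive mean
# and "output_1" the predictive variance.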
def predict_with_gpflow(model, X):
    print(model.signatures)

    X_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
    print(X_tensor.shape)

    predict_fn = model.signatures["serving_default"]
    result = predict_fn(Xnew=X_tensor)

    mean = result["output_0"].numpy()
    variance = result["output_1"].numpy()

    return mean.flatten(), variance.flatten()
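

# One embedding and one prediction per kinetic target. Plant-specific models return a point
# estimate; general models also return an uncertainty.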
def process_target(target, selected_models, sequence, prediction_type):
    """
    Process a single target for prediction using transformer embeddings and the specified model.
    """
    esm_model_name = selected_models[target]["esm_model"]
    layer = selected_models[target]["layer"]
    model = selected_models[target]["model"]

    embedding, _ = get_embedding(sequence, esm_model_name, layer)
    embedding = embedding.astype(np.float64)
    np.save(f"hf_embedding_{target}.npy", embedding)

    if prediction_type == "Plant-Specific":
        y_pred = model.predict(embedding)[0]
        return target, round(y_pred, 2)
    else:
        print(esm_model_name)
        print(layer)
        print(model)
        y_pred, y_uncertainty = predict_with_gpflow(model, embedding)
        return target, round(y_pred[0], 2), round(y_uncertainty[0], 2)
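

# Top-level prediction: the three targets are processed concurrently in a thread pool.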
def predict(sequence, prediction_type):
    """
    Predicts Specificity, kcatC, and KC for the given sequence and prediction type.
    """
    selected_models = plant_models if prediction_type == "Plant-Specific" else general_models

    with ThreadPoolExecutor() as executor:
        results = list(
            executor.map(
                lambda target: process_target(target, selected_models, sequence, prediction_type),
                selected_models.keys()
            )
        )

    if prediction_type == "Plant-Specific":
        formatted_results = [
            ["Specificity", results[0][1]],
            ["kcat\u1d9c", results[1][1]],
            ["K\u1d9c", results[2][1]],
        ]
    else:
        formatted_results = [
            ["Specificity", results[0][1], results[0][2]],
            ["kcat\u1d9c", results[1][1], results[1][2]],
            ["K\u1d9c", results[2][1], results[2][2]],
        ]

    return formatted_results
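

# Build the Gradio UI: a protein-sequence textbox and prediction-type selector in, a results table out.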
print("Creating Gradio interface...") |
|
interface = gr.Interface( |
|
fn=predict, |
|
inputs=[ |
|
gr.Textbox(label="Input Protein Sequence", |
|
value="MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYHIEPVPGEETQFIAYVAYPLDLFEEGSVTNMFTSIVGNVFGFKALAALRLEDLRIPPAYTKTFQGPPHGIQVERDKLNKYGRPLLGCTIKPKLGLSAKNYGRAVYECLRGGLDFTKDDENVNSQPFMRWRDRFLFCAEAIYKSQAETGEIKGHYLNATAGTCEEMIKRAVFARELGVPIVMHDYLTGGFTANTSLSHYCRDNGLLLHIHRAMHAVIDRQKNHGMHFRVLAKALRLSGGDHIHAGTVVGKLEGDRESTLGFVDLLRDDYVEKDRSRGIFFTQDWVSLPGVLPVASGGIHVWHMPALTEIFGDDSVLQFGGGTLGHPWGNAPGAVANRVALEACVQARNEGRDLAVEGNEIIREACKWSPELAAACEVWKEITFNFPTIDKLDGQE", |
|
lines=10, |
|
), |
|
gr.Radio(choices=["Plant-Specific", "General"], label="Prediction Type", value="Plant-Specific"), |
|
], |
|
outputs=gr.Dataframe( |
|
headers=["Target", "Prediction", "Uncertainty (for General)"], |
|
type="array" |
|
), |
|
title="Rubisco Kinetics Prediction", |
|
description=( |
|
"Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). " |
|
"Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions." |
|
), |
|
) |
|
|
|
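

# Start the app when executed as a script; launch(share=True) could be passed here to expose a
# temporary public URL.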
if __name__ == "__main__":
    interface.launch()