import gradio as gr
import joblib
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, EsmModel
import torch
import numpy as np
import random
import tensorflow as tf
import os
import pandas as pd
base_dir = "."
# Set random seed
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def load_model(model_path):
    """Load a GPflow model exported as a TensorFlow SavedModel."""
    print(f"Loading model from {model_path}...")
    return tf.saved_model.load(model_path)
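
# Assumption (inferred from predict_with_gpflow below): each exported SavedModel
# provides a "serving_default" signature that takes an "Xnew" tensor and returns
# the predictive mean and variance as "output_0" and "output_1".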
print("Loading models...")
plant_models = {
"Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
"kcatC": {"model": joblib.load("kcatC.pkl"), "esm_model": "facebook/esm2_t36_3B_UR50D", "layer": 11},
"KC": {"model": joblib.load("KC.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 4},
}
general_models = {
"Specificity": {"model": load_model(f"Specificity"), "esm_model": "facebook/esm2_t33_650M_UR50D", "layer": 33},
"kcatC": {"model": load_model(f"kcatC"), "esm_model": "facebook/esm2_t12_35M_UR50D", "layer": 7},
"KC": {"model": load_model(f"KC"), "esm_model": "facebook/esm2_t30_150M_UR50D", "layer": 26},
}

# Function to generate ESM embeddings for a protein sequence
def get_embedding(sequence, esm_model_name, layer):
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")

    # Note: the tokenizer and model are reloaded on every call; see the optional
    # caching sketch after this function for one way to avoid that.
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)

    # Tokenize the sequence
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)

    # Generate embeddings
    with torch.no_grad():
        outputs = model(**inputs)
        hidden_states = outputs.hidden_states  # Embedding layer plus every transformer layer
        embedding = hidden_states[layer].mean(dim=1).numpy()  # Mean-pool over the sequence dimension

    # Convert to a DataFrame with named feature columns (D1, D2, ...)
    feature_columns = {f"D{i+1}": embedding[0, i] for i in range(embedding.shape[1])}
    embedding_df = pd.DataFrame([feature_columns])
    print(embedding_df)
    return embedding_df.values, embedding_df
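
# Optional caching sketch (an addition, not part of the original pipeline): keep
# one tokenizer/model pair per ESM checkpoint in memory so repeated requests do
# not reload the weights each time. Assumes there is enough memory to hold the
# loaded checkpoints; get_embedding above could call this helper instead.
from functools import lru_cache

@lru_cache(maxsize=None)
def _load_esm(esm_model_name):
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
    model.eval()  # inference mode
    return tokenizer, model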

def predict_with_gpflow(model, X):
    """Predict mean and variance with a GPflow model exported as a SavedModel."""
    print(model.signatures)

    # Convert input to a TensorFlow tensor
    X_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
    print(X_tensor.shape)

    # Get predictions from the exported serving signature
    predict_fn = model.signatures["serving_default"]
    result = predict_fn(Xnew=X_tensor)  # Pass Xnew explicitly
    mean = result["output_0"].numpy()  # Adjust output key names if needed
    variance = result["output_1"].numpy()

    # Return mean and variance as flat numpy arrays
    return mean.flatten(), variance.flatten()
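
# Example usage (hypothetical variable and shapes): for a (1, D) float64
# embedding row, the call returns two 1-D arrays with the predictive mean and
# variance, e.g.
#   mean, variance = predict_with_gpflow(general_models["kcatC"]["model"], embedding)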

def process_target(target, selected_models, sequence, prediction_type):
    """
    Process a single target for prediction using transformer embeddings and the specified model.
    """
    # Get model and embedding details
    esm_model_name = selected_models[target]["esm_model"]
    layer = selected_models[target]["layer"]
    model = selected_models[target]["model"]

    # Generate embeddings in the required format
    embedding, _ = get_embedding(sequence, esm_model_name, layer)
    embedding = embedding.astype(np.float64)
    np.save(f"hf_embedding_{target}.npy", embedding)

    if prediction_type == "Plant-Specific":
        # Random Forest prediction (point estimate only)
        y_pred = model.predict(embedding)[0]
        return target, round(y_pred, 2)
    else:
        # GPflow prediction (mean and uncertainty)
        print(esm_model_name, layer, model)
        y_pred, y_uncertainty = predict_with_gpflow(model, embedding)
        return target, round(y_pred[0], 2), round(y_uncertainty[0], 2)

def predict(sequence, prediction_type):
    """
    Predicts Specificity, kcatC, and KC for the given sequence and prediction type.
    """
    # Select the appropriate model set
    selected_models = plant_models if prediction_type == "Plant-Specific" else general_models

    # Predict for all targets in parallel (map preserves the order of the keys)
    with ThreadPoolExecutor() as executor:
        results = list(
            executor.map(
                lambda target: process_target(target, selected_models, sequence, prediction_type),
                selected_models.keys()
            )
        )

    # Format results for the three-column output table
    if prediction_type == "Plant-Specific":
        # Random Forest gives point estimates only, so fill the uncertainty column with "N/A"
        formatted_results = [
            ["Specificity", results[0][1], "N/A"],
            ["kcat\u1d9c", results[1][1], "N/A"],
            ["K\u1d9c", results[2][1], "N/A"],
        ]
    else:
        formatted_results = [
            ["Specificity", results[0][1], results[0][2]],
            ["kcat\u1d9c", results[1][1], results[1][2]],
            ["K\u1d9c", results[2][1], results[2][2]],
        ]
    return formatted_results
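
# Optional smoke test (a sketch; the environment variable name is an arbitrary
# choice, and the check is gated so it never runs during normal app startup).
# It exercises the full pipeline with a short fragment of the example Rubisco
# sequence, which is enough to verify wiring but not meant to give a meaningful
# prediction, and it requires the model files to be present locally.
if os.environ.get("RUBISCO_SMOKE_TEST"):
    print(predict("MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEE", "Plant-Specific"))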

# Define Gradio interface
print("Creating Gradio interface...")
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(
            label="Input Protein Sequence",
            value="MSPQTETKASVGFKAGVKEYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGTWTTVWTDGLTSLDRYKGRCYHIEPVPGEETQFIAYVAYPLDLFEEGSVTNMFTSIVGNVFGFKALAALRLEDLRIPPAYTKTFQGPPHGIQVERDKLNKYGRPLLGCTIKPKLGLSAKNYGRAVYECLRGGLDFTKDDENVNSQPFMRWRDRFLFCAEAIYKSQAETGEIKGHYLNATAGTCEEMIKRAVFARELGVPIVMHDYLTGGFTANTSLSHYCRDNGLLLHIHRAMHAVIDRQKNHGMHFRVLAKALRLSGGDHIHAGTVVGKLEGDRESTLGFVDLLRDDYVEKDRSRGIFFTQDWVSLPGVLPVASGGIHVWHMPALTEIFGDDSVLQFGGGTLGHPWGNAPGAVANRVALEACVQARNEGRDLAVEGNEIIREACKWSPELAAACEVWKEITFNFPTIDKLDGQE",
            lines=10,
        ),  # Input: text box for the protein sequence
        gr.Radio(choices=["Plant-Specific", "General"], label="Prediction Type", value="Plant-Specific"),  # Radio buttons for the prediction type
    ],
    outputs=gr.Dataframe(
        headers=["Target", "Prediction", "Uncertainty (General only)"],
        type="array"
    ),  # Output: results table
    title="Rubisco Kinetics Prediction",
    description=(
        "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
        "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
    ),
)

if __name__ == "__main__":
    interface.launch()