PiKaHa commited on
Commit
b3237f2
·
1 Parent(s): 0d226af

Add requirements.txt and app.py

Browse files
Files changed (2) hide show
  1. app.py +144 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import joblib
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from transformers import AutoTokenizer, AutoModel, EsmModel
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import tensorflow as tf
9
+ import os
10
+ from keras.layers import TFSMLayer
11
+
12
+ print(f"TensorFlow Version: {tf.__version__}")
13
+
14
+ base_dir = "."
15
+
16
+ # Set random seed
17
+ SEED = 42
18
+ np.random.seed(SEED)
19
+ random.seed(SEED)
20
+ torch.manual_seed(SEED)
21
+ if torch.cuda.is_available():
22
+ torch.cuda.manual_seed(SEED)
23
+ torch.cuda.manual_seed_all(SEED)
24
+
25
+ # Ensure deterministic behavior
26
+ torch.backends.cudnn.deterministic = True
27
+ torch.backends.cudnn.benchmark = False
28
+
29
+
30
+ def load_model(model_path):
31
+ print(f"Loading model from {model_path}...")
32
+ #print(f"Loading model from {model_path} using TFSMLayer...")
33
+ #return TFSMLayer(model_path, call_endpoint="serving_default")
34
+ #return tf.keras.models.load_model(model_path)
35
+ return tf.saved_model.load(model_path)
36
+
37
+
38
+
39
+ # Load Random Forest models and configurations
40
+ print("Loading models...")
41
+ plant_models = {
42
+ "Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
43
+ "kcatC": {"model": joblib.load("kcatC.pkl"), "esm_model": "facebook/esm2_t36_3B_UR50D", "layer": 11},
44
+ "KC": {"model": joblib.load("KC.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 4},
45
+ }
46
+
47
+ general_models = {
48
+ "Specificity": {"model": load_model(f"Specificity"), "esm_model": "facebook/esm2_t33_650M_UR50D", "layer": 33},
49
+ "kcatC": {"model": load_model(f"kcatC"), "esm_model": "facebook/esm2_t12_35M_UR50D", "layer": 7},
50
+ "KC": {"model": load_model(f"KC"), "esm_model": "facebook/esm2_t30_150M_UR50D", "layer": 26},
51
+ }
52
+
53
+
54
+ # Function to generate embeddings
55
+ def get_embedding(sequence, esm_model_name, layer):
56
+ print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
57
+ tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
58
+ model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
59
+
60
+ # Tokenize the sequence
61
+ inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
62
+
63
+ # Generate embeddings
64
+ with torch.no_grad():
65
+ outputs = model(**inputs)
66
+ hidden_states = outputs.hidden_states # Retrieve all hidden states
67
+ embedding = hidden_states[layer].mean(dim=1).numpy() # Average pooling
68
+
69
+ return embedding
70
+
71
+
72
+ def predict_with_gpflow(model, X):
73
+ # Convert input to TensorFlow tensor
74
+ X_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
75
+
76
+ # Get predictions
77
+ predict_fn = model.predict_f_compiled
78
+ mean, variance = predict_fn(X_tensor)
79
+
80
+ # Return mean and variance as numpy arrays
81
+ return mean.numpy().flatten(), variance.numpy().flatten()
82
+ # Function to predict based on user choice
83
+ def predict(sequence, prediction_type):
84
+ # Select the appropriate model set
85
+ selected_models = plant_models if prediction_type == "Plant-Specific" else general_models
86
+
87
+ def process_target(target):
88
+ esm_model_name = selected_models[target]["esm_model"]
89
+ layer = selected_models[target]["layer"]
90
+ model = selected_models[target]["model"]
91
+
92
+ # Generate embedding
93
+ embedding = get_embedding(sequence, esm_model_name, layer)
94
+
95
+ if prediction_type == "Plant-Specific":
96
+ # Random Forest prediction
97
+ prediction = model.predict(embedding)[0]
98
+ return target, round(prediction, 2)
99
+ else:
100
+ # GPflow prediction
101
+ mean, variance = predict_with_gpflow(model, embedding)
102
+ return target, round(mean[0], 2), round(variance[0], 2)
103
+
104
+ # Predict for all targets in parallel
105
+ with ThreadPoolExecutor() as executor:
106
+ results = list(executor.map(process_target, selected_models.keys()))
107
+
108
+ # Format results
109
+ if prediction_type == "Plant-Specific":
110
+ formatted_results = [
111
+ ["Specificity", results[0][1]],
112
+ ["kcat\u1d9c", results[1][1]],
113
+ ["K\u1d9c", results[2][1]],
114
+ ]
115
+ else:
116
+ formatted_results = [
117
+ ["Specificity", results[0][1], results[0][2]],
118
+ ["kcat\u1d9c", results[1][1], results[1][2]],
119
+ ["K\u1d9c", results[2][1], results[2][2]],
120
+ ]
121
+
122
+ return formatted_results
123
+
124
+ # Define Gradio interface
125
+ print("Creating Gradio interface...")
126
+ interface = gr.Interface(
127
+ fn=predict,
128
+ inputs=[
129
+ gr.Textbox(label="Input Protein Sequence"), # Input: Text box for sequence
130
+ gr.Radio(choices=["Plant-Specific", "General"], label="Prediction Type", value="Plant-Specific"), # Dropdown for selection
131
+ ],
132
+ outputs=gr.Dataframe(
133
+ headers=["Target", "Prediction", "Uncertainty (for General)"],
134
+ type="array"
135
+ ), # Output: Table
136
+ title="Rubisco Kinetics Prediction",
137
+ description=(
138
+ "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
139
+ "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
140
+ ),
141
+ )
142
+
143
+ if __name__ == "__main__":
144
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ gradio
4
+ joblib
5
+ numpy
6
+ scikit-learn
7
+ gpflow