# NOTE: the file-size line, commit-hash gutter and line-number gutter that preceded the code here were page-scrape residue, not source.
import gradio as gr
import sys
import random
import os
import pandas as pd
import torch
import itertools
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import shap
sys.path.append("scripts/")
from foldseek_util import get_struc_seq
from utils import seed_everything, save_pickle
from models import PLTNUM_PreTrainedModel
from datasets_ import PLTNUMDataset
class Config:
    """Inference settings shared by the prediction and SHAP entry points.

    Every setting is a plain instance attribute. Functions in this file
    attach further attributes at run time (``model``, ``architecture``,
    ``model_path``, ``device``, ``tokenizer``, ``token_length``).

    Args:
        **overrides: Optional attribute values that replace (or extend) the
            defaults below, e.g. ``Config(batch_size=8)``.
    """

    def __init__(self, **overrides):
        self.batch_size = 2  # batch size for the DataLoader and the SHAP explainer
        self.use_amp = False  # toggles torch.amp.autocast during prediction
        self.num_workers = 1  # DataLoader worker processes
        self.max_length = 512  # tokenizer truncation length
        # "both" makes predict() use max_length + 1; the exact meaning of
        # "left" is defined by the dataset class -- TODO confirm.
        self.used_sequence = "left"
        self.padding_side = "right"  # forwarded to AutoTokenizer.from_pretrained
        self.task = "classification"  # "classification" applies a sigmoid to model outputs
        self.sequence_col = "sequence"  # DataFrame column holding the input sequences
        self.seed = 42  # passed to seed_everything
        self.max_evals = 10  # forwarded to the SHAP explainer call
        # Applied last so callers can override any default without subclassing.
        for name, value in overrides.items():
            setattr(self, name, value)
def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Predict stability for uploaded PDB files and write the results to a CSV.

    Args:
        model_choice: Base model name; "SaProt" selects element 2 of the
            foldseek result, anything else selects element 0.
        organism_choice: Target organism, forwarded to predict_stability_core.
        pdb_files: Uploaded file objects exposing a ``.name`` path. ``None``
            or an empty list produces a CSV that contains only the header.
        cfg: Optional settings object; a fresh ``Config`` is built when omitted.

    Returns:
        Path of the written CSV ("/tmp/predictions.csv"). Files for which no
        sequence could be extracted are listed first with empty prediction
        values.
    """
    # A ``Config()`` default in the signature would be created once and then
    # shared (and mutated) by every call, so build it per call instead.
    cfg = Config() if cfg is None else cfg
    pdb_files = pdb_files or []

    failed_names = []
    file_names = []
    input_sequences = []

    if pdb_files:
        # Make the bundled foldseek binary executable; once per request is enough.
        os.system("chmod 777 bin/foldseek")

    for pdb_file in pdb_files:
        pdb_path = pdb_file.name
        file_name = os.path.basename(pdb_path)
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            failed_names.append(file_name)
            continue
        file_names.append(file_name)
        input_sequences.append(
            sequences[2] if model_choice == "SaProt" else sequences[0]
        )

    if input_sequences:
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, input_sequences, cfg
        )
    else:
        # Nothing to score: skip downloading/loading the model entirely.
        raw_prediction, binary_prediction = [], []

    missing = [None] * len(failed_names)
    results = {
        "file_name": failed_names + file_names,
        "raw prediction value": missing + raw_prediction,
        "binary prediction value": missing + binary_prediction,
    }
    df = pd.DataFrame(results)
    output_csv = "/tmp/predictions.csv"
    df.to_csv(output_csv, index=False)
    return output_csv
def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=None):
    """Predict stability for one typed-in sequence and write the result to a CSV.

    Args:
        model_choice: Base model name, forwarded to predict_stability_core.
        organism_choice: Target organism, forwarded to predict_stability_core.
        sequence: Sequence text from the UI. All whitespace (including line
            breaks from a multi-line paste) is removed before prediction.
        cfg: Optional settings object; a fresh ``Config`` is built when omitted.

    Returns:
        Path of the written CSV ("/tmp/predictions.csv") on success, or a
        human-readable message string when the input is empty or an error
        occurs.
        NOTE(review): the UI binds this return value to a file output, so a
        message string may not render as intended -- confirm.
    """
    try:
        # The textbox is multi-line; whitespace is not part of a sequence.
        sequence = "".join(sequence.split()) if sequence else ""
        if not sequence:
            return "No valid sequence provided."
        # Build the config per call; a signature-level ``Config()`` would be
        # one object shared and mutated by every request.
        cfg = Config() if cfg is None else cfg
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, [sequence], cfg
        )
        df = pd.DataFrame(
            {
                "sequence": [sequence],
                "raw prediction value": raw_prediction,
                "binary prediction value": binary_prediction,
            }
        )
        output_csv = "/tmp/predictions.csv"
        df.to_csv(output_csv, index=False)
        return output_csv
    except Exception as e:
        return f"An error occurred: {str(e)}"
def predict_stability_core(model_choice, organism_choice, sequences, cfg=None):
    """Select the PLTNUM checkpoint for the chosen model/organism and run predict().

    Args:
        model_choice: Base model name; becomes part of the checkpoint id and
            is stored as ``cfg.architecture``.
        organism_choice: "Human" selects the HeLa checkpoint; any other value
            selects NIH3T3.
        sequences: List of input sequences.
        cfg: Optional settings object; a fresh ``Config`` is built when
            omitted. ``model``, ``architecture`` and ``model_path`` are set
            on it.

    Returns:
        Whatever predict() returns: ``(raw_predictions, binary_predictions)``.
    """
    # This function writes to cfg, so a signature-level ``Config()`` default
    # would leak state between calls; create it per call instead.
    cfg = Config() if cfg is None else cfg
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    model_id = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = model_id
    cfg.architecture = model_choice
    cfg.model_path = model_id
    return predict(cfg, sequences)
def get_foldseek_seq(pdb_path):
    """Run get_struc_seq on chain "A" of *pdb_path* and return that chain's entry.

    Args:
        pdb_path: Path to the PDB file to parse.

    Returns:
        The value stored under key "A" in the get_struc_seq result.
    """
    chain_id = "A"
    # Random id handed to get_struc_seq as process_id; presumably keeps
    # concurrent runs from colliding -- TODO confirm.
    run_id = random.randint(0, 10000000)
    per_chain = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        [chain_id],
        process_id=run_id,
    )
    return per_chain[chain_id]
def predict(cfg, sequences):
    """Run the PLTNUM model over *sequences* and return raw and binary predictions.

    Args:
        cfg: Settings object. Reads ``architecture``, ``used_sequence``,
            ``max_length``, ``seed``, ``sequence_col``, ``model_path``,
            ``padding_side``, ``batch_size``, ``num_workers``, ``use_amp`` and
            ``task``; sets ``token_length``, ``device`` and ``tokenizer``.
        sequences: List of input sequences.

    Returns:
        ``(predictions, binary)``: the flattened model outputs (sigmoid-applied
        when ``cfg.task == "classification"``) and a 0/1 list obtained by
        thresholding each value at 0.5. Empty input yields ``([], [])``
        without loading the model.
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if len(sequences) == 0:
        return [], []

    # "both" needs one extra position, but cfg may be reused across calls;
    # restore the configured value afterwards so the increment cannot pile up.
    configured_max_length = cfg.max_length
    if cfg.used_sequence == "both":
        cfg.max_length += 1
    try:
        seed_everything(cfg.seed)
        df = pd.DataFrame({cfg.sequence_col: sequences})

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.model_path, padding_side=cfg.padding_side
        )
        cfg.tokenizer = tokenizer

        dataset = PLTNUMDataset(cfg, df, train=False)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=cfg.device == "cuda",
            drop_last=False,
        )

        model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
        model.to(cfg.device)
        model.eval()

        predictions = []
        with torch.no_grad():
            for inputs, _ in dataloader:
                inputs = inputs.to(cfg.device)
                with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                    outputs = model(inputs)
                    preds = (
                        torch.sigmoid(outputs)
                        if cfg.task == "classification"
                        else outputs
                    )
                predictions += preds.cpu().tolist()
    finally:
        cfg.max_length = configured_max_length

    # Each batch contributes a list of per-sample lists; flatten one level.
    predictions = list(itertools.chain.from_iterable(predictions))
    return predictions, [1 if x > 0.5 else 0 for x in predictions]
def calculate_shap_values_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Compute SHAP values for uploaded PDB files and pickle them to disk.

    Args:
        model_choice: Base model name; "SaProt" selects element 2 of the
            foldseek result, anything else selects element 0.
        organism_choice: Target organism, forwarded to
            calculate_shap_values_core.
        pdb_files: Uploaded file objects exposing a ``.name`` path.
        cfg: Optional settings object; a fresh ``Config`` is built when omitted.

    Returns:
        Path of the written pickle ("/tmp/shap_values.pkl").

    Raises:
        ValueError: If no sequence could be extracted from any uploaded file
            (including when nothing was uploaded).
    """
    # Build the config per call; a signature-level ``Config()`` would be one
    # object shared and mutated by every request.
    cfg = Config() if cfg is None else cfg
    pdb_files = pdb_files or []

    if pdb_files:
        # Make the bundled foldseek binary executable; once per request is enough.
        os.system("chmod 777 bin/foldseek")

    input_sequences = []
    for pdb_file in pdb_files:
        sequences = get_foldseek_seq(pdb_file.name)
        if not sequences:
            # Same policy as the prediction path: skip files without a usable chain.
            continue
        input_sequences.append(
            sequences[2] if model_choice == "SaProt" else sequences[0]
        )

    if not input_sequences:
        raise ValueError(
            "No valid sequence could be extracted from the uploaded PDB files."
        )

    shap_values = calculate_shap_values_core(
        model_choice, organism_choice, input_sequences, cfg
    )
    output_path = "/tmp/shap_values.pkl"
    save_pickle(output_path, shap_values)
    return output_path
def calculate_shap_fn(texts, model, cfg):
    """Score a batch of texts with *model* and return sigmoid outputs as NumPy.

    Args:
        texts: Sequence of texts; a single-element batch is passed to the
            tokenizer as the bare element, otherwise ``texts.tolist()`` is used.
        model: Callable taking the dict of tokenized tensors.
        cfg: Settings object providing ``tokenizer``, ``max_length`` and
            ``device``.

    Returns:
        NumPy array holding the sigmoid of the model outputs.
    """
    batch = texts[0] if len(texts) == 1 else texts.tolist()
    encoded = cfg.tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.max_length,
    )
    # Move every tokenizer output onto the model's device.
    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(cfg.device)
    with torch.no_grad():
        logits = model(on_device)
    probabilities = torch.sigmoid(logits)
    return probabilities.detach().cpu().numpy()
def calculate_shap_values_core(model_choice, organism_choice, sequences, cfg=None):
    """Load the selected PLTNUM checkpoint and compute SHAP values for *sequences*.

    Args:
        model_choice: Base model name; becomes part of the checkpoint id and
            is stored as ``cfg.architecture``.
        organism_choice: "Human" selects the HeLa checkpoint; any other value
            selects NIH3T3.
        sequences: List of input sequences to explain.
        cfg: Optional settings object; a fresh ``Config`` is built when
            omitted. ``model``, ``architecture``, ``model_path``, ``device``
            and ``tokenizer`` are set on it.

    Returns:
        The object returned by calling the ``shap.Explainer`` on *sequences*.
    """
    # This function writes to cfg, so a signature-level ``Config()`` default
    # would leak state between calls; create it per call instead.
    cfg = Config() if cfg is None else cfg
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    model_id = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = model_id
    cfg.architecture = model_choice
    cfg.model_path = model_id
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    cfg.tokenizer = tokenizer

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg).to(cfg.device)
    model.eval()

    # build an explainer using a token masker
    explainer = shap.Explainer(lambda x: calculate_shap_fn(x, model, cfg), cfg.tokenizer)
    shap_values = explainer(
        sequences,
        batch_size=cfg.batch_size,
        max_evals=cfg.max_evals,
    )
    return shap_values
# Gradio interface: shared model/organism selectors on top, one tab per input
# mode (PDB upload or raw sequence), followed by usage notes.
with gr.Blocks() as demo:
    gr.Markdown(
        "\n"
        "# PLTNUM: Protein LifeTime Neural Model\n"
        "**Predict the protein half-life from its sequence or PDB file.**\n"
    )
    gr.Image(
        "https://raw.githubusercontent.com/sagawatatsuya/PLTNUM/main/model-image.png",
        label="Model Image",
    )

    # Model and organism selectors share one row to avoid layout issues.
    with gr.Row():
        model_choice = gr.Radio(
            choices=["SaProt", "ESM2"],
            label="Select PLTNUM's base model.",
            value="SaProt",
        )
        organism_choice = gr.Radio(
            choices=["Mouse", "Human"],
            label="Select the target organism.",
            value="Mouse",
        )

    with gr.Tabs():
        with gr.TabItem("Upload PDB File"):
            gr.Markdown("### Upload your PDB files:")
            pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")

            # Stability prediction for the uploaded structures.
            pdb_predict_button = gr.Button("Predict Stability")
            pdb_prediction_output = gr.File(label="Download Predictions")
            pdb_predict_button.click(
                fn=predict_stability_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=pdb_prediction_output,
            )

            # SHAP values for the same uploads.
            shap_button = gr.Button("Calculate SHAP Values")
            shap_output = gr.File(label="Download SHAP Values")
            shap_button.click(
                fn=calculate_shap_values_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=shap_output,
            )

        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown("### Enter the protein sequence:")
            sequence = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
                lines=8,
            )
            sequence_predict_button = gr.Button("Predict Stability")
            sequence_prediction_output = gr.File(label="Download Predictions")
            sequence_predict_button.click(
                fn=predict_stability_with_sequence,
                inputs=[model_choice, organism_choice, sequence],
                outputs=sequence_prediction_output,
            )

    gr.Markdown(
        "\n"
        "### How to Use:\n"
        "- **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.\n"
        "- **Select Organism**: Choose between 'Mouse' or 'Human'.\n"
        "- **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.\n"
        "- **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.\n"
        "- **Predict**: Click 'Predict Stability' to receive the prediction.\n"
    )
    gr.Markdown(
        "\n"
        "### About the Tool\n"
        "This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.\n"
    )

demo.launch()
# NOTE: a stray table delimiter left here by the page scrape was not part of the source.