|
import gradio as gr |
|
import sys |
|
import random |
|
import os |
|
import pandas as pd |
|
import torch |
|
import itertools |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoTokenizer |
|
import shap |
|
|
|
sys.path.append("scripts/") |
|
from foldseek_util import get_struc_seq |
|
from utils import seed_everything, save_pickle |
|
from models import PLTNUM_PreTrainedModel |
|
from datasets_ import PLTNUMDataset |
|
|
|
|
|
class Config:
    """Default inference settings for the PLTNUM demo.

    Only static defaults live here. The prediction and SHAP helpers attach
    per-request fields (for example ``model_path``, ``device`` and
    ``tokenizer``) to the instance at runtime.
    """

    def __init__(self):
        # DataLoader / runtime behaviour
        self.batch_size: int = 2
        self.use_amp: bool = False  # forwarded as `enabled=` to torch.amp.autocast
        self.num_workers: int = 1

        # Tokenisation
        self.max_length: int = 512
        self.used_sequence: str = "left"  # predict() adds 1 to max_length when this is "both"
        self.padding_side: str = "right"

        # Task / input layout
        self.task: str = "classification"  # a sigmoid is applied to model outputs in this mode
        self.sequence_col: str = "sequence"  # column name of the DataFrame built in predict()

        # Reproducibility and SHAP budget
        self.seed: int = 42  # passed to seed_everything
        self.max_evals: int = 10  # passed to the SHAP explainer call
|
|
|
|
|
|
|
def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Predict stability for uploaded PDB files and return the path of a results CSV.

    Args:
        model_choice: "SaProt" or "ESM2". SaProt is fed index 2 of the foldseek
            result for each file, any other choice is fed index 0.
        organism_choice: "Human" or "Mouse"; forwarded to predict_stability_core.
        pdb_files: uploaded files as delivered by ``gr.File(file_count="multiple")``.
            Items may be objects exposing the path as ``.name`` or plain path strings.
        cfg: optional Config. A fresh one is created per call when omitted, so
            settings mutated during one request never leak into the next.

    Returns:
        Path to a CSV with columns "file_name", "raw prediction value" and
        "binary prediction value": one row per uploaded file, in upload order.
        Files whose sequence could not be extracted get empty prediction cells.

    Raises:
        gr.Error: when no file was uploaded (shown to the user by Gradio).
    """
    import tempfile
    import warnings

    if not pdb_files:
        raise gr.Error("Please upload at least one PDB file.")
    cfg = Config() if cfg is None else cfg

    # Best-effort: make the bundled foldseek binary executable once per request.
    # 0o755 (not 777) so the binary does not become world-writable.
    try:
        os.chmod("bin/foldseek", 0o755)
    except OSError as err:
        warnings.warn(f"Could not chmod bin/foldseek: {err}")

    file_names = []
    input_sequences = []  # None marks a file whose sequence could not be extracted
    for pdb_file in pdb_files:
        pdb_path = getattr(pdb_file, "name", pdb_file)
        file_names.append(os.path.basename(pdb_path))
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            input_sequences.append(None)
        else:
            input_sequences.append(sequences[2] if model_choice == "SaProt" else sequences[0])

    valid_sequences = [seq for seq in input_sequences if seq is not None]
    raw_prediction, binary_prediction = [], []
    if valid_sequences:  # skip loading the model when nothing can be scored
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, valid_sequences, cfg
        )

    # Scatter predictions back onto the files so rows keep the upload order.
    raw_iter, binary_iter = iter(raw_prediction), iter(binary_prediction)
    results = {
        "file_name": file_names,
        "raw prediction value": [
            None if seq is None else next(raw_iter) for seq in input_sequences
        ],
        "binary prediction value": [
            None if seq is None else next(binary_iter) for seq in input_sequences
        ],
    }

    df = pd.DataFrame(results)
    # A private directory per request keeps concurrent users from overwriting each
    # other's results, while the download is still named "predictions.csv".
    output_csv = os.path.join(tempfile.mkdtemp(), "predictions.csv")
    df.to_csv(output_csv, index=False)

    return output_csv
|
|
|
def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=None):
    """Predict stability for one pasted sequence and return the path of a results CSV.

    Args:
        model_choice: "SaProt" or "ESM2"; forwarded to predict_stability_core.
        organism_choice: "Human" or "Mouse"; forwarded to predict_stability_core.
        sequence: text from the sequence box. All whitespace (line breaks from
            pasting, spaces) is removed before prediction.
        cfg: optional Config. A fresh one is created per call when omitted.

    Returns:
        Path to a CSV with columns "sequence", "raw prediction value" and
        "binary prediction value".

    Raises:
        gr.Error: when the sequence is empty or prediction fails. The output
            component is a gr.File, so an error string returned instead would be
            treated as a file path and its message lost; gr.Error is shown to
            the user instead.
    """
    import tempfile

    sequence = "".join((sequence or "").split())
    if not sequence:
        raise gr.Error("No valid sequence provided.")
    cfg = Config() if cfg is None else cfg

    try:
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, [sequence], cfg
        )
        df = pd.DataFrame(
            {
                "sequence": [sequence],
                "raw prediction value": raw_prediction,
                "binary prediction value": binary_prediction,
            }
        )
        # Per-request directory: concurrent requests (and the PDB tab) must not
        # overwrite each other's file; the download name stays "predictions.csv".
        output_csv = os.path.join(tempfile.mkdtemp(), "predictions.csv")
        df.to_csv(output_csv, index=False)
    except Exception as e:  # UI boundary: surface any failure to the user
        raise gr.Error(f"An error occurred: {str(e)}") from e

    return output_csv
|
|
|
|
|
def predict_stability_core(model_choice, organism_choice, sequences, cfg=None):
    """Select the PLTNUM checkpoint for the chosen model/organism and run `predict`.

    Args:
        model_choice: base model name ("SaProt" or "ESM2"); used both in the
            checkpoint name and as ``cfg.architecture``.
        organism_choice: "Human" selects the HeLa checkpoint; any other value
            (i.e. "Mouse") selects the NIH3T3 checkpoint.
        sequences: list of input sequences for the model.
        cfg: optional Config. A fresh one is created per call when omitted;
            the old ``Config()`` default was a single instance shared and
            mutated by every call.

    Returns:
        Whatever ``predict`` returns: ``(raw_predictions, binary_predictions)``.
    """
    cfg = Config() if cfg is None else cfg

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint

    return predict(cfg, sequences)
|
|
|
|
|
def get_foldseek_seq(pdb_path, chain="A"):
    """Extract the foldseek-derived sequences of one chain from a structure file.

    Args:
        pdb_path: path to the structure (PDB) file.
        chain: chain identifier to extract; defaults to "A", the only chain
            this demo uses.

    Returns:
        The entry ``get_struc_seq`` produces for ``chain``. Callers index it
        with [0] (ESM2 input) and [2] (SaProt input); presumably
        (amino-acid seq, 3Di seq, combined seq) — confirm in foldseek_util.
        Returns None when the result has no entry for ``chain``, so the
        callers' ``if not sequences`` guard can skip the file instead of the
        request failing with a KeyError.
    """
    # NOTE(review): assumes get_struc_seq returns a dict keyed by chain id
    # (the original indexed it with ["A"]).
    parsed = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        [chain],
        # Random id presumably keeps temp files of concurrent runs apart — confirm.
        process_id=random.randint(0, 10000000),
    )
    return parsed.get(chain)
|
|
|
|
|
def predict(cfg, sequences, threshold=0.5):
    """Run a PLTNUM checkpoint over `sequences` and return its predictions.

    Args:
        cfg: inference settings. Besides the Config defaults it must carry
            ``architecture`` and ``model_path`` (set by predict_stability_core).
            The object is shallow-copied and the copy is adjusted, so the
            caller's cfg is left untouched. Previously ``max_length += 1``
            accumulated on the caller's (often shared) cfg on every call.
        sequences: list of input sequences, one model input string each.
        threshold: cut-off used to binarise the raw predictions (default 0.5).

    Returns:
        ``(raw, binary)``:
        - raw: model outputs flattened to a list of floats, passed through a
          sigmoid when ``cfg.task == "classification"``.
        - binary: ``1 if value > threshold else 0`` for each raw value.
        Both lists are empty when ``sequences`` is empty.
    """
    import copy

    cfg = copy.copy(cfg)
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    if cfg.used_sequence == "both":
        # One extra position in "both" mode — TODO confirm meaning in PLTNUMDataset.
        cfg.max_length += 1

    if not sequences:  # nothing to score: skip tokenizer/model loading
        return [], []

    seed_everything(cfg.seed)
    df = pd.DataFrame({cfg.sequence_col: sequences})

    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        # Pinned host memory only helps host->GPU copies; on CPU it just warns.
        pin_memory=cfg.device == "cuda",
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    for inputs, _ in dataloader:
        inputs = inputs.to(cfg.device)
        with torch.no_grad(), torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
            outputs = model(inputs)
            preds = torch.sigmoid(outputs) if cfg.task == "classification" else outputs
        predictions += preds.cpu().tolist()

    # Each sample contributes a nested list (presumably model output is (batch, 1)); flatten it.
    predictions = list(itertools.chain.from_iterable(predictions))

    return predictions, [1 if x > threshold else 0 for x in predictions]
|
|
|
|
|
|
|
def calculate_shap_values_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Compute SHAP values for uploaded PDB files and return the path of a pickle.

    Args:
        model_choice: "SaProt" or "ESM2". SaProt is fed index 2 of the foldseek
            result for each file, any other choice is fed index 0.
        organism_choice: "Human" or "Mouse"; forwarded to calculate_shap_values_core.
        pdb_files: uploaded files as delivered by ``gr.File(file_count="multiple")``.
            Items may be objects exposing the path as ``.name`` or plain path strings.
        cfg: optional Config. A fresh one is created per call when omitted.

    Returns:
        Path to a pickle (written with ``save_pickle``) holding the SHAP result,
        with one entry per uploaded file in upload order.

    Raises:
        gr.Error: when no file was uploaded or a file's sequence cannot be
            extracted. Skipping such a file silently would misalign the
            results with the files.
    """
    import tempfile
    import warnings

    if not pdb_files:
        raise gr.Error("Please upload at least one PDB file.")
    cfg = Config() if cfg is None else cfg

    # Best-effort: make the bundled foldseek binary executable once per request.
    try:
        os.chmod("bin/foldseek", 0o755)
    except OSError as err:
        warnings.warn(f"Could not chmod bin/foldseek: {err}")

    input_sequences = []
    for pdb_file in pdb_files:
        pdb_path = getattr(pdb_file, "name", pdb_file)
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            raise gr.Error(
                f"Could not extract a sequence from {os.path.basename(pdb_path)}."
            )
        input_sequences.append(sequences[2] if model_choice == "SaProt" else sequences[0])

    shap_values = calculate_shap_values_core(model_choice, organism_choice, input_sequences, cfg)

    # Per-request directory so concurrent users do not overwrite each other's
    # output; the download keeps the name "shap_values.pkl".
    output_path = os.path.join(tempfile.mkdtemp(), "shap_values.pkl")
    save_pickle(output_path, shap_values)

    return output_path
|
|
|
|
|
def calculate_shap_fn(texts, model, cfg):
    """Model wrapper handed to ``shap.Explainer``: a batch of texts in, scores out.

    Args:
        texts: batch of (masked) input strings from the SHAP text masker.
            Assumed to be a NumPy array, since ``.tolist()`` is called on it — confirm.
        model: PLTNUM model, called with a dict of tokenizer tensors.
        cfg: provides ``tokenizer``, ``max_length``, ``device`` and ``task``.

    Returns:
        NumPy array of model outputs. A sigmoid is applied when
        ``cfg.task == "classification"``, the same rule ``predict`` uses, so
        SHAP explains the same quantity that is reported as the prediction.
    """
    # A one-element batch is passed to the tokenizer as a bare string
    # (HF tokenizers still return a batch of size one for it).
    texts = texts[0] if len(texts) == 1 else texts.tolist()

    inputs = cfg.tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.max_length,
    )
    inputs = {k: v.to(cfg.device) for k, v in inputs.items()}

    with torch.no_grad():  # no autograd graph, so no .detach() is needed
        outputs = model(inputs)
        if cfg.task == "classification":
            outputs = torch.sigmoid(outputs)

    return outputs.cpu().numpy()
|
|
|
|
|
def calculate_shap_values_core(model_choice, organism_choice, sequences, cfg=None):
    """Load the PLTNUM checkpoint for the chosen model/organism and explain it with SHAP.

    Args:
        model_choice: base model name ("SaProt" or "ESM2"); used in the checkpoint
            name and as ``cfg.architecture``.
        organism_choice: "Human" selects the HeLa checkpoint; any other value
            (i.e. "Mouse") selects the NIH3T3 checkpoint.
        sequences: list of input sequences to explain.
        cfg: optional Config. A fresh one is created per call when omitted.

    Returns:
        The object returned by calling the ``shap.Explainer`` on ``sequences``.
    """
    cfg = Config() if cfg is None else cfg

    # A fresh Config carries no checkpoint fields, so pick the checkpoint here
    # with the same naming scheme as predict_stability_core. Without this,
    # cfg.model_path raises AttributeError.
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.architecture = model_choice
    cfg.model_path = cfg.model
    # Same value predict() sets before building the model.
    cfg.token_length = 2 if model_choice == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    seed_everything(cfg.seed)
    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg).to(cfg.device)
    model.eval()

    # The tokenizer doubles as the SHAP masker for text inputs.
    explainer = shap.Explainer(lambda x: calculate_shap_fn(x, model, cfg), cfg.tokenizer)

    shap_values = explainer(
        sequences,
        batch_size=cfg.batch_size,
        max_evals=cfg.max_evals,
    )

    return shap_values
|
|
|
|
|
|
|
# Gradio UI. Both tabs share the model/organism selectors defined in the row above them.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # PLTNUM: Protein LifeTime Neural Model

        **Predict the protein half-life from its sequence or PDB file.**
        """
    )

    gr.Image(
        "https://github.com/sagawatatsuya/PLTNUM/blob/main/model-image.png?raw=true",
        label="Model Image",
    )

    with gr.Row():
        model_choice = gr.Radio(
            choices=["SaProt", "ESM2"],
            label="Select PLTNUM's base model.",
            value="SaProt",
        )
        organism_choice = gr.Radio(
            choices=["Mouse", "Human"],
            label="Select the target organism.",
            value="Mouse",
        )

    with gr.Tabs():
        with gr.TabItem("Upload PDB File"):
            gr.Markdown("### Upload your PDB files:")
            pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")

            pdb_predict_button = gr.Button("Predict Stability")
            pdb_prediction_output = gr.File(label="Download Predictions")
            pdb_predict_button.click(
                fn=predict_stability_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=pdb_prediction_output,
            )

            calculate_shap_values_button = gr.Button("Calculate SHAP Values")
            shap_values_output = gr.File(label="Download SHAP Values")
            calculate_shap_values_button.click(
                fn=calculate_shap_values_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=shap_values_output,
            )

        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown("### Enter the protein sequence:")
            sequence = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
                lines=8,
            )

            sequence_predict_button = gr.Button("Predict Stability")
            sequence_prediction_output = gr.File(label="Download Predictions")
            sequence_predict_button.click(
                fn=predict_stability_with_sequence,
                inputs=[model_choice, organism_choice, sequence],
                outputs=sequence_prediction_output,
            )

    gr.Markdown(
        """
        ### How to Use:

        - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
        - **Select Organism**: Choose between 'Mouse' or 'Human'.
        - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload one or more files.
        - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
        - **Predict**: Click 'Predict Stability' to receive the prediction as a CSV file.
        - **Explain**: In the 'Upload PDB File' tab, click 'Calculate SHAP Values' to download the SHAP values for the uploaded structures as a pickle file.
        """
    )

    gr.Markdown(
        """
        ### About the Tool

        This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.
        """
    )


# Launch only when run as a script (e.g. `python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()
|
|