|
import gradio as gr |
|
import sys |
|
import random |
|
import os |
|
import pandas as pd |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoTokenizer |
|
sys.path.append("scripts/") |
|
from foldseek_util import get_struc_seq |
|
from utils import seed_everything |
|
from models import PLTNUM_PreTrainedModel |
|
from datasets import PLTNUMDataset |
|
|
|
class Config:
    """Default inference settings consumed by ``predict``.

    The attributes are plain class-level defaults; callers attach further
    per-request attributes (model name, device, tokenizer, ...) to the
    object they pass around.
    """

    # DataLoader batch size.
    batch_size = 2
    # Passed as ``enabled`` to autocast during inference.
    use_amp = False
    # DataLoader worker count.
    num_workers = 1
    # Token budget; predict() raises it by one when used_sequence == "both".
    max_length = 512
    # predict() only special-cases the value "both".
    used_sequence = "left"
    # Forwarded to AutoTokenizer.from_pretrained.
    padding_side = "right"
    # "classification" makes predict() apply a sigmoid to the model output.
    task = "classification"
    # Column name of the one-row DataFrame that carries the input sequence.
    sequence_col = "sequence"
    # predict() passes cfg.seed to seed_everything(); without a default here
    # that attribute lookup raises AttributeError. The value is arbitrary.
    seed = 42
|
|
|
|
|
def _ensure_executable(path):
    """Best-effort: add the execute permission bits to ``path``."""
    try:
        os.chmod(path, os.stat(path).st_mode | 0o111)
    except OSError:
        # Deliberately ignored: the shell ``chmod`` this replaces was also
        # fire-and-forget, and a missing or unusable binary is reported
        # when it is actually invoked.
        pass


def predict_stability(cfg, model_choice, organism_choice, pdb_file=None, sequence=None):
    """Produce the stability-prediction message for one request.

    Args:
        cfg: Mutable settings object; ``model``, ``architecture`` and
            ``model_path`` are set on it before it is handed to ``predict``.
        model_choice: Base model name; "SaProt" selects a different entry of
            the extracted sequences than any other value.
        organism_choice: "Human" selects the HeLa model, anything else the
            NIH3T3 model.
        pdb_file: Optional uploaded structure. Either an object exposing the
            path as ``.name`` or the path string itself. Takes precedence
            over ``sequence`` when given.
        sequence: Optional sequence string, used when no file is given.

    Returns:
        A human-readable result or error message.
    """
    if pdb_file:
        # Uploads may arrive as a file-like wrapper or as a bare path string.
        pdb_path = getattr(pdb_file, "name", pdb_file)
        # Only execute permission is needed; avoid a shell and avoid making
        # the binary world-writable.
        _ensure_executable("bin/foldseek")
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # presumably index 0 is the plain amino-acid sequence and index 2 the
        # structure-aware one -- TODO confirm against foldseek_util
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    if not sequence:
        return "No valid input provided."

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    model_id = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = model_id
    cfg.architecture = model_choice
    cfg.model_path = model_id
    output = predict(cfg, sequence)
    return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {output}..."
|
|
|
|
|
def get_foldseek_seq(pdb_path):
    """Extract the chain "A" sequences from a PDB file via foldseek.

    Args:
        pdb_path: Path of the structure file handed to ``get_struc_seq``.

    Returns:
        The entry ``get_struc_seq`` produced for chain "A", or ``None`` when
        the result has no such chain, so callers can report a failed
        extraction instead of hitting a ``KeyError``.
    """
    parsed_seqs = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        # Random id so concurrent requests are unlikely to share one.
        process_id=random.randint(0, 10000000),
    )
    try:
        return parsed_seqs["A"]
    except KeyError:
        return None
|
|
|
|
|
def predict(cfg, sequence):
    """Run the PLTNUM model selected by ``cfg`` on a single sequence.

    Args:
        cfg: Mutable settings object. Reads ``architecture``,
            ``used_sequence``, ``max_length``, ``sequence_col``,
            ``model_path``, ``padding_side``, ``batch_size``,
            ``num_workers``, ``use_amp``, ``task`` and, when present,
            ``seed``. Sets ``token_length``, ``device`` and ``tokenizer``.
        sequence: Sequence string, wrapped in a one-row DataFrame for the
            dataset.

    Returns:
        dict with two keys: "raw prediction values" (flat list of floats)
        and "binary prediction values" (the raw values thresholded at 0.5).
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    # The "both" mode needs one extra token. The bump is undone in the
    # ``finally`` below so that a cfg reused across requests does not grow
    # by one on every call.
    base_max_length = cfg.max_length
    if cfg.used_sequence == "both":
        cfg.max_length = base_max_length + 1

    try:
        # Fall back to a fixed seed when the config does not define one.
        seed_everything(getattr(cfg, "seed", 42))

        df = pd.DataFrame({cfg.sequence_col: [sequence]})

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.model_path, padding_side=cfg.padding_side
        )
        cfg.tokenizer = tokenizer

        dataset = PLTNUMDataset(cfg, df, train=False)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
        model.to(cfg.device)
        model.eval()

        predictions = []
        # The dataset yields (inputs, label) pairs; the label is unused here.
        for inputs, _ in dataloader:
            inputs = inputs.to(cfg.device)
            # torch.amp.autocast requires the device type as first argument.
            with torch.no_grad(), torch.amp.autocast(
                device_type=cfg.device, enabled=cfg.use_amp
            ):
                logits = model(inputs)
                preds = torch.sigmoid(logits) if cfg.task == "classification" else logits
            # Flatten so each prediction is a scalar; the output may carry a
            # trailing singleton dimension (TODO confirm against the model),
            # which would make the threshold below compare a list to a float.
            predictions.extend(preds.cpu().reshape(-1).tolist())
    finally:
        cfg.max_length = base_max_length

    return {
        "raw prediction values": predictions,
        "binary prediction values": [1 if x > 0.5 else 0 for x in predictions],
    }
|
|
|
|
|
|
|
|
|
# Gradio passes the values of ``inputs`` to ``fn`` positionally, and the UI
# has no component for ``cfg``. Wiring ``predict_stability`` directly would
# therefore shift every argument by one and put the sequence text in the
# ``pdb_file`` slot. These thin handlers supply a fresh config per request
# and pass the payload by keyword.
def _predict_from_pdb(model_name, organism, uploaded_file):
    """Click handler for the "Upload PDB File" tab."""
    return predict_stability(Config(), model_name, organism, pdb_file=uploaded_file)


def _predict_from_sequence(model_name, organism, sequence_text):
    """Click handler for the "Enter Protein Sequence" tab."""
    return predict_stability(Config(), model_name, organism, sequence=sequence_text)


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # PLTNUM: Protein LifeTime Neural Model
        **Predict the protein half-life from its sequence or PDB file.**
        """
    )

    gr.Image("https://github.com/sagawatatsuya/PLTNUM/blob/main/model-image.png?raw=true", label="Model Image")

    # Model and organism selectors shared by both input tabs.
    with gr.Row():
        model_choice = gr.Radio(
            choices=["SaProt", "ESM2"],
            label="Select PLTNUM's base model.",
            value="SaProt"
        )
        organism_choice = gr.Radio(
            choices=["Mouse", "Human"],
            label="Select the target organism.",
            value="Mouse"
        )

    with gr.Tabs():
        with gr.TabItem("Upload PDB File"):
            gr.Markdown("### Upload your PDB file:")
            pdb_file = gr.File(label="Upload PDB File")

            predict_button = gr.Button("Predict Stability")
            prediction_output = gr.Textbox(label="Stability Prediction", interactive=False)

            predict_button.click(fn=_predict_from_pdb, inputs=[model_choice, organism_choice, pdb_file], outputs=prediction_output)

        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown("### Enter the protein sequence:")
            sequence = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
                lines=8,
            )
            # Each tab has its own button/output pair; the names are reused
            # after the first pair has already been wired above.
            predict_button = gr.Button("Predict Stability")
            prediction_output = gr.Textbox(label="Stability Prediction", interactive=False)

            predict_button.click(fn=_predict_from_sequence, inputs=[model_choice, organism_choice, sequence], outputs=prediction_output)

    gr.Markdown(
        """
        ### How to Use:
        - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
        - **Select Organism**: Choose between 'Mouse' or 'Human'.
        - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
        - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
        - **Predict**: Click 'Predict Stability' to receive the prediction.
        """
    )

    gr.Markdown(
        """
        ### About the Tool
        This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.
        """
    )

demo.launch()
|
|