"""Gradio app for PLTNUM (Protein LifeTime Neural Model).

Users pick a base model (SaProt or ESM2) and an organism (Mouse or Human),
then either upload PDB files or paste a sequence. Predictions are returned as
a downloadable CSV file; for PDB uploads, SHAP attributions can also be
downloaded as a pickle file.
"""

import copy
import itertools
import os
import random
import stat
import sys
import tempfile

import gradio as gr
import pandas as pd
import shap
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Project modules live in scripts/, which is not otherwise on sys.path.
sys.path.append("scripts/")
from foldseek_util import get_struc_seq  # noqa: E402
from utils import seed_everything, save_pickle  # noqa: E402
from models import PLTNUM_PreTrainedModel  # noqa: E402
from datasets_ import PLTNUMDataset  # noqa: E402

# Path to the bundled foldseek binary handed to get_struc_seq.
FOLDSEEK_BIN = "bin/foldseek"


class Config:
    """Inference settings shared by the prediction and SHAP code paths."""

    def __init__(self):
        self.batch_size = 2
        self.use_amp = False  # enables torch.amp.autocast in predict()
        self.num_workers = 1
        self.max_length = 512
        # Consumed by PLTNUMDataset (not visible here); "both" also bumps
        # max_length by one in predict().
        self.used_sequence = "left"
        self.padding_side = "right"
        self.task = "classification"  # "classification" => sigmoid outputs
        self.sequence_col = "sequence"
        self.seed = 42
        self.max_evals = 10  # evaluation budget passed to the SHAP explainer


def _resolve_cfg(cfg):
    """Return *cfg*, or a fresh Config when it is None.

    A fresh instance per call replaces the former ``cfg=Config()`` default,
    which was a single object shared (and mutated) by every request.
    """
    return Config() if cfg is None else cfg


def _model_path(model_choice, organism_choice):
    """Return the Hugging Face repo id for the chosen base model and organism."""
    # "Human" models were trained on HeLa data; anything else maps to NIH3T3.
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    return f"sagawa/PLTNUM-{model_choice}-{cell_line}"


def _configure(cfg, model_choice, organism_choice):
    """Return a shallow copy of *cfg* pointing at the selected checkpoint.

    Copying keeps per-request settings from leaking into the caller's config.
    """
    cfg = copy.copy(_resolve_cfg(cfg))
    path = _model_path(model_choice, organism_choice)
    cfg.model = path
    cfg.architecture = model_choice
    cfg.model_path = path
    return cfg


def _new_output_path(filename):
    """Return a path named *filename* inside a fresh, private temp directory.

    A unique directory per call stops concurrent users from overwriting each
    other's downloads, while keeping the user-visible file name unchanged.
    """
    return os.path.join(tempfile.mkdtemp(prefix="pltnum_"), filename)


def _ensure_foldseek_executable():
    """Best-effort: add execute permission to the bundled foldseek binary."""
    try:
        mode = os.stat(FOLDSEEK_BIN).st_mode
        os.chmod(FOLDSEEK_BIN, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    except OSError:
        # The previous `os.system("chmod ...")` ignored failures as well; if the
        # binary truly cannot run, get_struc_seq will surface that error.
        pass


def _file_path(uploaded):
    """Return the filesystem path of a gr.File upload.

    Depending on the Gradio version, uploads arrive either as tempfile-like
    objects exposing ``.name`` or as plain path strings; accept both.
    """
    return getattr(uploaded, "name", uploaded)


def _pick_sequence(model_choice, sequences):
    """Select the foldseek output entry the chosen base model consumes."""
    # Index 2 is fed to SaProt, index 0 to ESM2 (presumably the combined
    # structure-aware sequence vs. the plain amino-acid sequence -- confirm
    # against foldseek_util.get_struc_seq).
    return sequences[2] if model_choice == "SaProt" else sequences[0]


def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Predict stability for uploaded PDB files and return a CSV path.

    Args:
        model_choice: "SaProt" or "ESM2".
        organism_choice: "Human" or "Mouse".
        pdb_files: Uploaded files from ``gr.File(file_count="multiple")``.
        cfg: Optional Config; a fresh one is used when omitted.

    Returns:
        Path to a CSV with one row per uploaded file, in upload order. Files
        whose chain A could not be parsed get empty prediction values.

    Raises:
        gr.Error: If no files were uploaded.
    """
    cfg = _resolve_cfg(cfg)
    if not pdb_files:
        raise gr.Error("Please upload at least one PDB file.")
    _ensure_foldseek_executable()

    file_names = []
    usable = []  # (row index, model input) for files foldseek could parse
    for pdb_file in pdb_files:
        pdb_path = _file_path(pdb_file)
        file_names.append(os.path.basename(pdb_path))
        sequences = get_foldseek_seq(pdb_path)
        if sequences:
            usable.append((len(file_names) - 1, _pick_sequence(model_choice, sequences)))

    raw_values = [None] * len(file_names)
    binary_values = [None] * len(file_names)
    if usable:  # skip loading the model when nothing could be parsed
        rows, inputs = zip(*usable)
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, list(inputs), cfg
        )
        for row, raw, binary in zip(rows, raw_prediction, binary_prediction):
            raw_values[row] = raw
            binary_values[row] = binary

    df = pd.DataFrame(
        {
            "file_name": file_names,
            "raw prediction value": raw_values,
            "binary prediction value": binary_values,
        }
    )
    output_csv = _new_output_path("predictions.csv")
    df.to_csv(output_csv, index=False)
    return output_csv


def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=None):
    """Predict stability for a single pasted sequence and return a CSV path.

    Whitespace (e.g. line breaks from a wrapped paste) is removed first.

    Raises:
        gr.Error: If the sequence is empty or prediction fails. The output
            component is a gr.File, so errors are raised rather than returned
            as a message string.
    """
    cfg = _resolve_cfg(cfg)
    sequence = "".join((sequence or "").split())
    if not sequence:
        raise gr.Error("No valid sequence provided.")
    try:
        raw_prediction, binary_prediction = predict_stability_core(
            model_choice, organism_choice, [sequence], cfg
        )
    except Exception as e:  # UI boundary: show the failure to the user
        raise gr.Error(f"An error occurred: {str(e)}") from e

    df = pd.DataFrame(
        {
            "sequence": [sequence],
            "raw prediction value": raw_prediction,
            "binary prediction value": binary_prediction,
        }
    )
    output_csv = _new_output_path("predictions.csv")
    df.to_csv(output_csv, index=False)
    return output_csv


def predict_stability_core(model_choice, organism_choice, sequences, cfg=None):
    """Run predict() with the checkpoint matching the model/organism choice.

    Returns:
        (raw predictions, binary predictions) as produced by predict().
    """
    cfg = _configure(cfg, model_choice, organism_choice)
    return predict(cfg, sequences)


def get_foldseek_seq(pdb_path):
    """Parse chain A of *pdb_path* with foldseek.

    Returns:
        The chain-A entry of get_struc_seq's result (indexed by callers at 0
        and 2), or None when chain A is absent from the result.
    """
    parsed = get_struc_seq(
        FOLDSEEK_BIN,
        pdb_path,
        ["A"],
        process_id=random.randint(0, 10000000),
    )
    # NOTE(review): assumes get_struc_seq returns a dict keyed by chain id and
    # simply omits chains it cannot find -- confirm in foldseek_util.
    return parsed.get("A")


def predict(cfg, sequences):
    """Score *sequences* with the PLTNUM checkpoint at ``cfg.model_path``.

    The caller's cfg is copied, so repeated calls do not accumulate changes
    (notably the ``max_length`` bump for ``used_sequence == "both"``).

    Returns:
        Tuple of (raw predictions, binary predictions). Raw values are
        sigmoid scores when ``cfg.task == "classification"``; binary values
        threshold them at 0.5.
    """
    cfg = copy.copy(cfg)
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if cfg.used_sequence == "both":
        cfg.max_length += 1
    seed_everything(cfg.seed)

    df = pd.DataFrame({cfg.sequence_col: sequences})
    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=cfg.device == "cuda",
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(cfg.device)
            with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                outputs = model(inputs)
                preds = torch.sigmoid(outputs) if cfg.task == "classification" else outputs
            predictions += preds.cpu().tolist()

    # Each sample yields a row (presumably a single-element list for a
    # (batch, 1) output -- confirm); flatten into one value per sequence.
    predictions = list(itertools.chain.from_iterable(predictions))
    return predictions, [1 if x > 0.5 else 0 for x in predictions]


def calculate_shap_values_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Compute SHAP values for uploaded PDB files and return a pickle path.

    Raises:
        gr.Error: If no files were uploaded or a file has no parsable chain A.
    """
    cfg = _resolve_cfg(cfg)
    if not pdb_files:
        raise gr.Error("Please upload at least one PDB file.")
    _ensure_foldseek_executable()

    input_sequences = []
    for pdb_file in pdb_files:
        pdb_path = _file_path(pdb_file)
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            raise gr.Error(
                f"Could not extract chain A from {os.path.basename(pdb_path)}."
            )
        input_sequences.append(_pick_sequence(model_choice, sequences))

    shap_values = calculate_shap_values_core(
        model_choice, organism_choice, input_sequences, cfg
    )
    output_path = _new_output_path("shap_values.pkl")
    save_pickle(output_path, shap_values)
    return output_path


def calculate_shap_fn(texts, model, cfg):
    """Model wrapper for shap: map a batch of (masked) strings to scores.

    Returns:
        numpy array of sigmoid outputs, one row per input text.
    """
    # shap passes a batch of strings (previously handled via .tolist(), i.e.
    # presumably a numpy array); normalise to a plain list of str.
    texts = [str(t) for t in texts]
    inputs = cfg.tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.max_length,
    )
    inputs = {k: v.to(cfg.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(inputs)
    return torch.sigmoid(outputs).detach().cpu().numpy()


def calculate_shap_values_core(model_choice, organism_choice, sequences, cfg=None):
    """Explain the selected checkpoint's predictions on *sequences* with SHAP.

    Returns:
        The object produced by calling ``shap.Explainer`` on *sequences*.
    """
    cfg = _configure(cfg, model_choice, organism_choice)
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    seed_everything(cfg.seed)

    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg).to(cfg.device)
    model.eval()

    # Passing the tokenizer as the masker makes shap build a token-level masker.
    explainer = shap.Explainer(lambda x: calculate_shap_fn(x, model, cfg), cfg.tokenizer)
    return explainer(
        sequences,
        batch_size=cfg.batch_size,
        max_evals=cfg.max_evals,
    )


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # PLTNUM: Protein LifeTime Neural Model
        **Predict the protein half-life from its sequence or PDB file.**
        """
    )
    gr.Image(
        "https://raw.githubusercontent.com/sagawatatsuya/PLTNUM/main/model-image.png",
        label="Model Image",
    )

    # Model and Organism selection in the same row to avoid layout issues
    with gr.Row():
        model_choice = gr.Radio(
            choices=["SaProt", "ESM2"],
            label="Select PLTNUM's base model.",
            value="SaProt",
        )
        organism_choice = gr.Radio(
            choices=["Mouse", "Human"],
            label="Select the target organism.",
            value="Mouse",
        )

    with gr.Tabs():
        with gr.TabItem("Upload PDB File"):
            gr.Markdown("### Upload your PDB files:")
            pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")

            pdb_predict_button = gr.Button("Predict Stability")
            pdb_prediction_output = gr.File(label="Download Predictions")
            pdb_predict_button.click(
                fn=predict_stability_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=pdb_prediction_output,
            )

            calculate_shap_values_button = gr.Button("Calculate SHAP Values")
            shap_values_output = gr.File(label="Download SHAP Values")
            calculate_shap_values_button.click(
                fn=calculate_shap_values_with_pdb,
                inputs=[model_choice, organism_choice, pdb_files],
                outputs=shap_values_output,
            )

        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown("### Enter the protein sequence:")
            sequence = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
                lines=8,
            )
            predict_button = gr.Button("Predict Stability")
            prediction_output = gr.File(label="Download Predictions")
            predict_button.click(
                fn=predict_stability_with_sequence,
                inputs=[model_choice, organism_choice, sequence],
                outputs=prediction_output,
            )

    gr.Markdown(
        """
        ### How to Use:
        - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
        - **Select Organism**: Choose between 'Mouse' (NIH3T3 model) or 'Human' (HeLa model).
        - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload one or more files. Only chain A of each structure is used.
        - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
        - **Predict**: Click 'Predict Stability' to download the predictions as a CSV file.
        - **Explain** (PDB tab only): Click 'Calculate SHAP Values' to download SHAP attributions as a pickle file.
        """
    )

    gr.Markdown(
        """
        ### About the Tool
        This tool predicts protein stability (half-life) with PLTNUM models built on SaProt or ESM2.
        It supports both PDB file uploads and direct sequence input.
        """
    )

demo.launch()