File size: 3,514 Bytes
178b187 33e1d55 9329e39 e4d81ca 9329e39 a418c62 2683162 ef0bf75 f717639 ef0bf75 9329e39 e4d81ca 93a77af ef0bf75 93a77af 490a4c0 4d6e83e 33e1d55 9329e39 5c90d58 5ca8306 ef0bf75 93a77af 178b187 64c1665 178b187 9329e39 e4d81ca 9329e39 e4d81ca 9329e39 e4d81ca 9329e39 e4d81ca 9329e39 e4d81ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import sys
import random
import os
import pandas as pd
import torch
import itertools
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
sys.path.append("scripts/")
from foldseek_util import get_struc_seq
from utils import seed_everything
from models import PLTNUM_PreTrainedModel
from datasets_ import PLTNUMDataset
class Config:
    """Default inference settings read by ``predict_stability`` and ``predict``.

    Defaults are stored as class attributes; per-run values such as
    ``model``, ``architecture`` and ``model_path`` are assigned onto
    instances by the calling code.
    """

    # DataLoader / runtime settings
    batch_size: int = 2
    num_workers: int = 1
    use_amp: bool = False  # passed as `enabled` to torch.amp.autocast

    # Input handling
    max_length: int = 512
    used_sequence: str = "left"
    padding_side: str = "right"  # forwarded to AutoTokenizer.from_pretrained
    sequence_col: str = "sequence"  # DataFrame column holding the input sequence

    # Task type and reproducibility
    task: str = "classification"  # "classification" applies a sigmoid to outputs
    seed: int = 42
def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=None):
    """Predict protein stability from a sequence or an uploaded PDB file.

    The checkpoint ``sagawa/PLTNUM-{model_choice}-{cell_line}`` is used, where
    ``cell_line`` is "HeLa" when ``organism_choice == "Human"`` and "NIH3T3"
    otherwise.

    Args:
        model_choice: Architecture name (e.g. "SaProt"); also part of the
            checkpoint name.
        organism_choice: "Human" selects the HeLa model; any other value
            selects the NIH3T3 model.
        pdb_file: Optional uploaded-file object whose ``.name`` is a path to a
            PDB file. When given, the sequence extracted from it replaces
            ``sequence``.
        sequence: Optional input sequence, used when no PDB file is given.
        cfg: Optional configuration object. When omitted, a fresh ``Config()``
            is created for this call. A shared default instance would let
            attributes set on it leak between calls.

    Returns:
        A result string, or an error-message string if no sequence could be
        obtained.
    """
    if pdb_file:
        pdb_path = pdb_file.name  # path of the uploaded PDB file
        try:
            # Make the bundled foldseek binary executable. Use 0o755 rather
            # than 777 so the binary is never world-writable, and avoid
            # spawning a shell.
            os.chmod("bin/foldseek", 0o755)
        except OSError:
            # Best effort, like the original shell call: the binary may
            # already be executable. Any real problem surfaces when foldseek
            # is run below.
            pass
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # SaProt uses entry 2 of the parsed result and other models use
        # entry 0 (presumably the structure-aware vs. plain amino-acid
        # sequence; verify against foldseek_util.get_struc_seq).
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"

    if not sequence:
        return "No valid input provided."

    if cfg is None:
        cfg = Config()
    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint
    output = predict(cfg, sequence)
    return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {output}..."
def get_foldseek_seq(pdb_path, chain="A"):
    """Run foldseek on a PDB file and return the parsed entry for one chain.

    Args:
        pdb_path: Path to the PDB file.
        chain: Chain identifier to extract (default "A", the previously
            hard-coded chain).

    Returns:
        The value ``get_struc_seq`` produced for ``chain``. The caller indexes
        it at positions 0 and 2. Returns ``None`` if the result has no entry
        for ``chain``, so callers can detect failure with a truthiness check
        instead of getting an uncaught ``KeyError``.
    """
    parsed = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        [chain],
        # Random id; presumably keeps temporary files of concurrent calls
        # apart. TODO: confirm in foldseek_util.
        process_id=random.randint(0, 10000000),
    )
    try:
        return parsed[chain]
    except KeyError:
        return None
def predict(cfg, sequence):
    """Run a PLTNUM model on a single sequence and return its predictions.

    ``cfg`` is shallow-copied first, and all derived settings
    (``token_length``, ``device``, adjusted ``max_length``, ``tokenizer``) are
    written to the copy. The caller's object is therefore not modified, and
    repeated calls with the same ``cfg`` no longer keep incrementing
    ``max_length`` when ``used_sequence == "both"``.

    Args:
        cfg: Configuration object providing ``architecture``,
            ``used_sequence``, ``max_length``, ``seed``, ``sequence_col``,
            ``model_path``, ``padding_side``, ``batch_size``, ``num_workers``,
            ``use_amp`` and ``task``.
        sequence: The input sequence string.

    Returns:
        dict with:
          - "raw prediction values": flattened model outputs, passed through
            a sigmoid when ``cfg.task == "classification"``.
          - "binary prediction values": 1 where the raw value is > 0.5,
            otherwise 0.
    """
    import copy

    cfg = copy.copy(cfg)  # never mutate the caller's configuration
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if cfg.used_sequence == "both":
        # NOTE(review): one extra position when "both" is used. TODO: confirm
        # the reason in PLTNUMDataset.
        cfg.max_length += 1
    seed_everything(cfg.seed)

    df = pd.DataFrame({cfg.sequence_col: [sequence]})
    # The tokenizer must be attached to cfg before the dataset is built.
    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        # Pinned memory only speeds up host-to-GPU copies; on CPU it is useless.
        pin_memory=cfg.device == "cuda",
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(cfg.device)
            with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                outputs = model(inputs)
                preds = torch.sigmoid(outputs) if cfg.task == "classification" else outputs
            # Assumes each output row is itself a list (e.g. shape
            # (batch, 1)); the rows are flattened below. TODO: confirm.
            predictions.extend(preds.cpu().tolist())

    raw = list(itertools.chain.from_iterable(predictions))
    return {
        "raw prediction values": raw,
        "binary prediction values": [1 if x > 0.5 else 0 for x in raw],
    }
if __name__ == "__main__":
    # Demo run. Guarded so that importing this module does not load a
    # pretrained model and run inference as a side effect.
    print(predict_stability("SaProt", "Human", sequence="MELKQK"))