|
import sys |
|
import random |
|
import os |
|
import pandas as pd |
|
import torch |
|
import itertools |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoTokenizer |
|
|
|
sys.path.append("scripts/") |
|
from foldseek_util import get_struc_seq |
|
from utils import seed_everything |
|
from models import PLTNUM_PreTrainedModel |
|
from datasets_ import PLTNUMDataset |
|
|
|
|
|
class Config:
    """Default inference settings read by the prediction helpers in this file.

    Instances are also used as a mutable scratch pad: the prediction
    functions attach further attributes (model name, device, tokenizer, ...)
    at run time.
    """

    # Seed handed to seed_everything() before inference.
    seed: int = 42

    # When equal to "classification", predict() passes model outputs
    # through a sigmoid; any other value leaves them untouched.
    task: str = "classification"

    # Name of the DataFrame column that carries the input sequence.
    sequence_col: str = "sequence"

    # predict() adds 1 to max_length when used_sequence is "both".
    # NOTE(review): the remaining semantics of these two values are
    # presumably defined by PLTNUMDataset -- confirm there.
    used_sequence: str = "left"
    max_length: int = 512

    # Forwarded to AutoTokenizer.from_pretrained(padding_side=...).
    padding_side: str = "right"

    # DataLoader settings.
    batch_size: int = 2
    num_workers: int = 1

    # Enables torch.amp.autocast during the forward pass when True.
    use_amp: bool = False
|
|
|
|
|
|
|
def _ensure_executable(path):
    """Best-effort: add the execute permission bits to ``path``."""
    try:
        os.chmod(path, os.stat(path).st_mode | 0o111)
    except OSError:
        # Deliberately best-effort, like the shell ``chmod`` this replaces
        # (whose exit status was ignored): a missing or unusable binary
        # surfaces when foldseek is actually invoked.
        pass


def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=None):
    """Predict protein stability from a PDB file or a raw sequence.

    Args:
        model_choice: Architecture name; it is interpolated into the model
            identifier ``sagawa/PLTNUM-{model_choice}-{cell_line}``.
        organism_choice: ``"Human"`` selects the ``HeLa`` model; any other
            value selects ``NIH3T3``.
        pdb_file: Optional PDB input, either a path (``str`` /
            ``os.PathLike``) or an object exposing the path as ``.name``.
            When given, it takes precedence over ``sequence``.
        sequence: Optional sequence string, used when no PDB file is given.
        cfg: Optional configuration object. A fresh ``Config`` is created
            per call when omitted, so state no longer leaks between calls.
            A configuration passed in explicitly is updated in place
            (``model``, ``architecture``, ``model_path`` and whatever
            ``predict`` sets).

    Returns:
        A human-readable message: the formatted prediction, or an error
        text when no usable input was supplied or the PDB file yielded no
        sequence.
    """
    if pdb_file:
        if isinstance(pdb_file, (str, os.PathLike)):
            pdb_path = os.fspath(pdb_file)
        else:
            pdb_path = pdb_file.name

        # Only the execute bit is needed; avoid spawning a shell and avoid
        # making the binary world-writable.
        _ensure_executable("bin/foldseek")

        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."

        # NOTE(review): index 2 is presumably the structure-aware sequence
        # that SaProt expects and index 0 the plain amino-acid sequence --
        # confirm against get_struc_seq.
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    if not sequence:
        return "No valid input provided."

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"

    if cfg is None:
        cfg = Config()

    model_id = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = model_id
    cfg.architecture = model_choice
    cfg.model_path = model_id

    output = predict(cfg, sequence)
    return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {output}..."
|
|
|
|
|
def get_foldseek_seq(pdb_path, chain="A", foldseek_path="bin/foldseek"):
    """Run foldseek on a PDB file and return the parsed entry for one chain.

    Args:
        pdb_path: Path of the PDB file handed to ``get_struc_seq``.
        chain: Identifier of the chain to extract. Defaults to ``"A"``.
        foldseek_path: Location of the foldseek executable. Defaults to
            ``"bin/foldseek"``.

    Returns:
        The entry ``get_struc_seq`` produced for ``chain``, or ``None``
        when the result has no entry for that chain, so callers can report
        the failure instead of hitting a ``KeyError``.
    """
    parsed_seqs = get_struc_seq(
        foldseek_path,
        pdb_path,
        [chain],
        # A random id per call; presumably keeps foldseek's intermediate
        # files of concurrent calls apart -- TODO confirm in foldseek_util.
        process_id=random.randint(0, 10000000),
    )
    try:
        return parsed_seqs[chain]
    except KeyError:
        return None
|
|
|
|
|
def predict(cfg, sequence):
    """Run the model named by ``cfg.model_path`` on a single sequence.

    Args:
        cfg: Configuration object. It is read for ``architecture``,
            ``used_sequence``, ``max_length``, ``seed``, ``sequence_col``,
            ``model_path``, ``padding_side``, ``batch_size``,
            ``num_workers``, ``use_amp`` and ``task``. ``token_length``,
            ``device`` and ``tokenizer`` are set on it. ``max_length`` is
            raised by one for the duration of the call when
            ``used_sequence == "both"`` and restored afterwards, so
            repeated calls with the same object do not keep growing it.
        sequence: The sequence to score.

    Returns:
        A dict with two keys: ``"raw prediction values"`` (model outputs
        flattened by one level, passed through a sigmoid when
        ``cfg.task == "classification"``) and
        ``"binary prediction values"`` (1 where the raw value is above
        0.5, else 0).
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    base_max_length = cfg.max_length
    if cfg.used_sequence == "both":
        cfg.max_length = base_max_length + 1

    try:
        seed_everything(cfg.seed)
        df = pd.DataFrame({cfg.sequence_col: [sequence]})

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.model_path, padding_side=cfg.padding_side
        )
        cfg.tokenizer = tokenizer
        dataset = PLTNUMDataset(cfg, df, train=False)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            # Pinned host memory only helps host-to-GPU copies; requesting
            # it on a CPU-only machine merely triggers a warning.
            pin_memory=cfg.device == "cuda",
            drop_last=False,
        )

        model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
        model.to(cfg.device)
        model.eval()

        is_classification = cfg.task == "classification"
        predictions = []
        with torch.no_grad():
            for inputs, _ in dataloader:
                inputs = inputs.to(cfg.device)
                with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                    preds = model(inputs)
                    if is_classification:
                        preds = torch.sigmoid(preds)
                predictions += preds.cpu().tolist()
    finally:
        # Undo the temporary bump so the caller's configuration keeps its
        # original max_length across calls.
        cfg.max_length = base_max_length

    # Each batch contributes a list of per-sample lists; flatten one level.
    predictions = list(itertools.chain.from_iterable(predictions))
    return {
        "raw prediction values": predictions,
        "binary prediction values": [1 if x > 0.5 else 0 for x in predictions],
    }
|
|
|
if __name__ == "__main__":
    # Example run. Guarded so that importing this module does not trigger
    # model loading and inference, and printed so the result is not
    # silently discarded.
    print(predict_stability("SaProt", "Human", sequence="MELKQK"))