|
import gc |
|
import os |
|
import sys |
|
import argparse |
|
|
|
import pandas as pd |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from transformers import AutoTokenizer |
|
|
|
sys.path.append(".") |
|
from utils import seed_everything |
|
from models import PLTNUM |
|
from datasets import PLTNUMDataset |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser( |
|
description="Prediction script for protein sequence classification/regression." |
|
) |
|
parser.add_argument( |
|
"--data_path", |
|
type=str, |
|
required=True, |
|
help="Path to the input data.", |
|
) |
|
parser.add_argument( |
|
"--model", |
|
type=str, |
|
default="westlake-repl/SaProt_650M_AF2", |
|
help="Pretrained model name or path.", |
|
) |
|
parser.add_argument( |
|
"--architecture", |
|
type=str, |
|
default="SaProt", |
|
help="Model architecture: 'ESM2', 'SaProt', or 'LSTM'.", |
|
) |
|
parser.add_argument( |
|
"--model_path", |
|
type=str, |
|
required=True, |
|
help="Path to the model for prediction.", |
|
) |
|
parser.add_argument("--batch_size", type=int, default=4, help="Batch size.") |
|
parser.add_argument( |
|
"--seed", |
|
type=int, |
|
default=42, |
|
help="Seed for reproducibility.", |
|
) |
|
parser.add_argument( |
|
"--use_amp", |
|
action="store_true", |
|
default=False, |
|
help="Use AMP for mixed precision prediction.", |
|
) |
|
parser.add_argument( |
|
"--num_workers", |
|
type=int, |
|
default=4, |
|
help="Number of workers for data loading.", |
|
) |
|
parser.add_argument( |
|
"--max_length", |
|
type=int, |
|
default=512, |
|
help="Maximum input sequence length. Two tokens are used fo <cls> and <eos> tokens. So the actual length of input sequence is max_length - 2. Padding or truncation is applied to make the length of input sequence equal to max_length.", |
|
) |
|
parser.add_argument( |
|
"--used_sequence", |
|
type=str, |
|
default="left", |
|
help="Which part of the sequence to use: 'left', 'right', 'both', or 'internal'.", |
|
) |
|
parser.add_argument( |
|
"--padding_side", |
|
type=str, |
|
default="right", |
|
help="Padding side: 'right' or 'left'.", |
|
) |
|
parser.add_argument( |
|
"--output_dir", |
|
type=str, |
|
default="./output", |
|
help="Output directory.", |
|
) |
|
parser.add_argument( |
|
"--task", |
|
type=str, |
|
default="classification", |
|
help="Task type: 'classification' or 'regression'.", |
|
) |
|
parser.add_argument( |
|
"--sequence_col", |
|
type=str, |
|
default="aa_foldseek", |
|
help="Column name fot the input sequence.", |
|
) |
|
|
|
return parser.parse_args() |
|
|
|
|
|
def predict_fn(valid_loader, model, cfg): |
|
model.eval() |
|
predictions = [] |
|
|
|
for inputs, _ in valid_loader: |
|
inputs = inputs.to(cfg.device) |
|
with torch.no_grad(): |
|
with torch.cuda.amp.autocast(enabled=cfg.use_amp): |
|
preds = ( |
|
torch.sigmoid(model(inputs)) |
|
if cfg.task == "classification" |
|
else model(inputs) |
|
) |
|
predictions += preds.cpu().tolist() |
|
|
|
return predictions |
|
|
|
|
|
def predict(folds, model_path, cfg): |
|
dataset = PLTNUMDataset(cfg, folds, train=False) |
|
loader = DataLoader( |
|
dataset, |
|
batch_size=cfg.batch_size, |
|
shuffle=False, |
|
num_workers=cfg.num_workers, |
|
pin_memory=True, |
|
drop_last=False, |
|
) |
|
|
|
model = PLTNUM(cfg) |
|
model.load_state_dict(torch.load(model_path, map_location=cfg.device)) |
|
model.to(cfg.device) |
|
|
|
predictions = predict_fn(loader, model, cfg) |
|
|
|
folds["raw prediction values"] = predictions |
|
if cfg.task == "classification": |
|
folds["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions] |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
return folds |
|
|
|
|
|
if __name__ == "__main__": |
|
config = parse_args() |
|
config.token_length = 2 if config.architecture == "SaProt" else 1 |
|
config.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
if not os.path.exists(config.output_dir): |
|
os.makedirs(config.output_dir) |
|
|
|
if config.used_sequence == "both": |
|
config.max_length += 1 |
|
|
|
seed_everything(config.seed) |
|
|
|
df = pd.read_csv(config.data_path) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
config.model, padding_side=config.padding_side |
|
) |
|
config.tokenizer = tokenizer |
|
|
|
result = predict(df, config.model_path, config) |
|
result.to_csv(os.path.join(config.output_dir, "result.csv"), index=False) |
|
|