# PLTNUM/scripts/train.py
import gc
import os
import sys
import time
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, r2_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
sys.path.append(".")
from utils import AverageMeter, get_logger, seed_everything, timeSince
from datasets import PLTNUMDataset, LSTMDataset
from models import PLTNUM, LSTMModel
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
def parse_args():
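    """Parse command-line arguments for the training script."""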
parser = argparse.ArgumentParser(
description="Training script for protein half-life prediction."
)
parser.add_argument(
"--data_path",
type=str,
required=True,
help="Path to the training data.",
)
parser.add_argument(
"--model",
type=str,
default="westlake-repl/SaProt_650M_AF2",
help="Pretrained model name or path.",
)
parser.add_argument(
"--architecture",
type=str,
default="SaProt",
help="Model architecture: 'ESM2', 'SaProt', or 'LSTM'.",
)
parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate.")
parser.add_argument(
"--epochs",
type=int,
default=5,
help="Number of training epochs.",
)
parser.add_argument("--batch_size", type=int, default=4, help="Batch size.")
parser.add_argument(
"--seed",
type=int,
default=42,
help="Seed for reproducibility.",
)
parser.add_argument(
"--use_amp",
action="store_true",
default=False,
help="Use AMP for mixed precision training.",
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers for data loading.",
)
parser.add_argument(
"--max_length",
type=int,
default=512,
help="Maximum input sequence length. Two tokens are used fo <cls> and <eos> tokens. So the actual length of input sequence is max_length - 2. Padding or truncation is applied to make the length of input sequence equal to max_length.",
)
parser.add_argument(
"--used_sequence",
type=str,
default="left",
help="Which part of the sequence to use: 'left', 'right', 'both', or 'internal'.",
)
parser.add_argument(
"--padding_side",
type=str,
default="right",
help="Padding side: 'right' or 'left'.",
)
parser.add_argument(
"--mask_ratio",
type=float,
default=0.05,
help="Ratio of mask tokens for augmentation.",
)
parser.add_argument(
"--mask_prob",
type=float,
default=0.2,
help="Probability to apply mask augmentation",
)
parser.add_argument(
"--random_delete_ratio",
type=float,
default=0.1,
help="Ratio of deleting tokens in augmentation.",
)
parser.add_argument(
"--random_delete_prob",
type=float,
default=-1,
help="Probability to apply random delete augmentation.",
)
parser.add_argument(
"--random_change_ratio",
type=float,
default=0,
help="Ratio of changing tokens in augmentation.",
)
parser.add_argument(
"--truncate_augmentation_prob",
type=float,
default=-1,
help="Probability to apply truncate augmentation.",
)
parser.add_argument(
"--n_folds",
type=int,
default=10,
help="Number of folds for cross-validation.",
)
parser.add_argument(
"--print_freq",
type=int,
default=300,
help="Log print frequency.",
)
parser.add_argument(
"--freeze_layer",
type=int,
default=-1,
help="Freeze layers of the model. -1 means no layers are frozen.",
)
parser.add_argument(
"--output_dir",
type=str,
default="./output",
help="Output directory.",
)
parser.add_argument(
"--task",
type=str,
default="classification",
help="Task type: 'classification' or 'regression'.",
)
parser.add_argument(
"--target_col",
type=str,
default="Protein half-life average [h]",
help="Column name of the target.",
)
parser.add_argument(
"--sequence_col",
type=str,
default="aa_foldseek",
help="Column name fot the input sequence.",
)
return parser.parse_args()
def train_fn(train_loader, model, criterion, optimizer, epoch, cfg):
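    """Run one training epoch with optional AMP.

    Returns (average loss, accuracy) for classification or (average loss, R2) for regression.
    """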
model.train()
scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
losses = AverageMeter()
label_list, pred_list = [], []
start = time.time()
for step, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(cfg.device), labels.to(cfg.device)
labels = (
labels.float()
if cfg.task == "classification"
else labels.to(dtype=torch.half)
)
batch_size = labels.size(0)
with torch.cuda.amp.autocast(enabled=cfg.use_amp):
y_preds = model(inputs)
loss = criterion(y_preds, labels.view(-1, 1))
losses.update(loss.item(), batch_size)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
label_list += labels.tolist()
pred_list += y_preds.tolist()
if step % cfg.print_freq == 0 or step == len(train_loader) - 1:
if cfg.task == "classification":
                pred_list_new = (torch.sigmoid(torch.Tensor(pred_list)) > 0.5).to(dtype=torch.long)
                acc = accuracy_score(label_list, pred_list_new)
cfg.logger.info(
f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
f"Elapsed {timeSince(start, float(step + 1) / len(train_loader))} "
f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
f"LR: {optimizer.param_groups[0]['lr']:.8f} "
f"Accuracy: {acc:.4f}"
)
elif cfg.task == "regression":
r2 = r2_score(label_list, pred_list)
cfg.logger.info(
f"Epoch: [{epoch + 1}][{step}/{len(train_loader)}] "
f"Elapsed {timeSince(start, float(step + 1) / len(train_loader))} "
f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
f"R2 Score: {r2:.4f} "
f"LR: {optimizer.param_groups[0]['lr']:.8f}"
)
if cfg.task == "classification":
        pred_list_new = (torch.sigmoid(torch.Tensor(pred_list)) > 0.5).to(dtype=torch.long)
acc = accuracy_score(label_list, pred_list_new)
return losses.avg, acc
elif cfg.task == "regression":
return losses.avg, r2_score(label_list, pred_list)
def valid_fn(valid_loader, model, criterion, cfg):
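    """Evaluate the model on the validation fold.

    Returns (macro F1, accuracy, predictions) for classification or
    (average loss, R2, predictions) for regression.
    """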
losses = AverageMeter()
model.eval()
label_list, pred_list = [], []
start = time.time()
for step, (inputs, labels) in enumerate(valid_loader):
inputs, labels = inputs.to(cfg.device), labels.to(cfg.device)
labels = (
labels.float()
if cfg.task == "classification"
else labels.to(dtype=torch.half)
)
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=cfg.use_amp):
y_preds = (
torch.sigmoid(model(inputs))
if cfg.task == "classification"
else model(inputs)
)
loss = criterion(y_preds, labels.view(-1, 1))
losses.update(loss.item(), labels.size(0))
label_list += labels.tolist()
pred_list += y_preds.tolist()
if step % cfg.print_freq == 0 or step == len(valid_loader) - 1:
if cfg.task == "classification":
pred_list_new = (torch.Tensor(pred_list) > 0.5).to(dtype=torch.long)
                acc = accuracy_score(label_list, pred_list_new)
f1 = f1_score(label_list, pred_list_new, average="macro")
cfg.logger.info(
f"EVAL: [{step}/{len(valid_loader)}] "
f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader))} "
f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
f"Accuracy: {acc:.4f} "
f"F1 Score: {f1:.4f}"
)
elif cfg.task == "regression":
r2 = r2_score(label_list, pred_list)
cfg.logger.info(
f"EVAL: [{step}/{len(valid_loader)}] "
f"Elapsed {timeSince(start, float(step + 1) / len(valid_loader))} "
f"Loss: {losses.val:.4f}({losses.avg:.4f}) "
f"R2 Score: {r2:.4f}"
)
if cfg.task == "classification":
pred_list_new = (torch.Tensor(pred_list) > 0.5).to(dtype=torch.long)
return (
f1_score(label_list, pred_list_new, average="macro"),
accuracy_score(label_list, pred_list_new),
pred_list,
)
elif cfg.task == "regression":
return losses.avg, r2_score(label_list, pred_list), np.array(pred_list)
def train_loop(folds, fold, cfg):
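    """Train and validate a single cross-validation fold, saving the best checkpoint and its predictions."""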
cfg.logger.info(f"================== fold: {fold} training ======================")
train_folds = folds[folds["fold"] != fold].reset_index(drop=True)
valid_folds = folds[folds["fold"] == fold].reset_index(drop=True)
if cfg.architecture in ["ESM2", "SaProt"]:
train_dataset = PLTNUMDataset(cfg, train_folds, train=True)
valid_dataset = PLTNUMDataset(cfg, valid_folds, train=False)
elif cfg.architecture == "LSTM":
train_dataset = LSTMDataset(cfg, train_folds, train=True)
valid_dataset = LSTMDataset(cfg, valid_folds, train=False)
train_loader = DataLoader(
train_dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
pin_memory=True,
drop_last=True,
)
valid_loader = DataLoader(
valid_dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=True,
drop_last=False,
)
if cfg.architecture in ["ESM2", "SaProt"]:
model = PLTNUM(cfg)
if cfg.freeze_layer >= 0:
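            # Freeze every parameter that appears before encoder layer `freeze_layer`;
            # stop once that layer is reached so it and all later layers stay trainable.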
for name, param in model.named_parameters():
if f"model.encoder.layer.{cfg.freeze_layer}" in name:
break
param.requires_grad = False
model.config.save_pretrained(cfg.output_dir)
elif cfg.architecture == "LSTM":
model = LSTMModel(cfg)
model.to(cfg.device)
optimizer = Adam(model.parameters(), lr=cfg.lr)
if cfg.architecture in ["ESM2", "SaProt"]:
        scheduler = CosineAnnealingLR(
            optimizer, T_max=2, eta_min=1.0e-6, last_epoch=-1
        )
elif cfg.architecture == "LSTM":
scheduler = get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=0, num_training_steps=cfg.epochs, num_cycles=0.5
)
criterion = nn.BCEWithLogitsLoss() if cfg.task == "classification" else nn.MSELoss()
best_score = 0 if cfg.task == "classification" else float("inf")
for epoch in range(cfg.epochs):
start_time = time.time()
# train
avg_loss, train_score = train_fn(
train_loader, model, criterion, optimizer, epoch, cfg
)
scheduler.step()
# eval
val_score, val_score2, predictions = valid_fn(
valid_loader, model, criterion, cfg
)
elapsed = time.time() - start_time
if cfg.task == "classification":
cfg.logger.info(
f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} train_acc: {train_score:.4f} valid_acc: {val_score2:.4f} valid_f1: {val_score:.4f} time: {elapsed:.0f}s"
)
elif cfg.task == "regression":
cfg.logger.info(
f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} train_r2: {train_score:.4f} valid_r2: {val_score2:.4f} valid_loss: {val_score:.4f} time: {elapsed:.0f}s"
)
if (cfg.task == "classification" and best_score < val_score) or (
cfg.task == "regression" and best_score > val_score
):
best_score = val_score
cfg.logger.info(f"Epoch {epoch+1} - Save Best Score: {val_score:.4f} Model")
torch.save(
predictions,
                os.path.join(cfg.output_dir, "predictions.pth"),
)
torch.save(
model.state_dict(),
os.path.join(cfg.output_dir, f"model_fold{fold}.pth"),
)
predictions = torch.load(
        os.path.join(cfg.output_dir, "predictions.pth"), map_location="cpu"
)
valid_folds["prediction"] = predictions
cfg.logger.info(f"[Fold{fold}] Best score: {best_score}")
torch.cuda.empty_cache()
gc.collect()
return valid_folds
def get_embedding(folds, fold, path, cfg):
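    """Load a trained checkpoint and return embeddings for the sequences in the given validation fold."""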
valid_folds = folds[folds["fold"] == fold].reset_index(drop=True)
valid_dataset = PLTNUMDataset(cfg, valid_folds, train=False)
valid_loader = DataLoader(
valid_dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=True,
drop_last=False,
)
model = PLTNUM(cfg)
model.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
model.to(device)
model.eval()
embedding_list = []
for inputs, _ in valid_loader:
inputs = inputs.to(device)
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=cfg.use_amp):
embedding = model.create_embedding(inputs)
embedding_list += embedding.tolist()
torch.cuda.empty_cache()
gc.collect()
return embedding_list
if __name__ == "__main__":
config = parse_args()
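    # SaProt's structure-aware vocabulary encodes each residue as two characters
    # (amino acid + Foldseek structure token), so it needs token_length = 2.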
config.token_length = 2 if config.architecture == "SaProt" else 1
config.device = device
if not os.path.exists(config.output_dir):
os.makedirs(config.output_dir)
if config.used_sequence == "both":
config.max_length += 1
LOGGER = get_logger(os.path.join(config.output_dir, "output"))
config.logger = LOGGER
seed_everything(config.seed)
train_df = (
pd.read_csv(config.data_path)
.drop_duplicates(subset=[config.sequence_col], keep="first")
.reset_index(drop=True)
)
train_df["T1/2 [h]"] = train_df[config.target_col]
if config.task == "classification":
train_df["target"] = (
train_df["T1/2 [h]"] > np.median(train_df["T1/2 [h]"])
).astype(int)
train_df["class"] = train_df["target"]
elif config.task == "regression":
train_df["log1p(T1/2 [h])"] = np.log1p(train_df["T1/2 [h]"])
train_df["log1p(T1/2 [h])"] = (
train_df["log1p(T1/2 [h])"] - min(train_df["log1p(T1/2 [h])"])
) / (max(train_df["log1p(T1/2 [h])"]) - min(train_df["log1p(T1/2 [h])"]))
train_df["target"] = train_df["log1p(T1/2 [h])"]
def get_class(row, class_num=5):
denom = 1 / class_num
num = row["log1p(T1/2 [h])"]
for target in range(class_num):
if denom * target <= num and num < denom * (target + 1):
break
row["class"] = target
return row
train_df = train_df.apply(get_class, axis=1)
train_df["fold"] = -1
kf = StratifiedKFold(
n_splits=config.n_folds, shuffle=True, random_state=config.seed
)
for fold, (trn_ind, val_ind) in enumerate(kf.split(train_df, train_df["class"])):
train_df.loc[val_ind, "fold"] = int(fold)
if config.architecture in ["ESM2", "SaProt"]:
tokenizer = AutoTokenizer.from_pretrained(
config.model, padding_side=config.padding_side
)
tokenizer.save_pretrained(config.output_dir)
config.tokenizer = tokenizer
oof_df = pd.DataFrame()
for fold in range(config.n_folds):
_oof_df = train_loop(train_df, fold, config)
oof_df = pd.concat([oof_df, _oof_df], axis=0)
oof_df = oof_df.reset_index(drop=True)
oof_df.to_csv(os.path.join(config.output_dir, "oof_df.csv"), index=False)