import argparse
import os
import random
import string
import sys
import pandas as pd
from datetime import datetime
sys.path.append("../")
import numpy as np
import torch
import lightgbm as lgb
import sklearn.metrics as metrics
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_curve, f1_score, precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import EsmTokenizer, EsmForMaskedLM, BertModel, BertTokenizer, AutoTokenizer, EsmModel
from utils.downstream_disgenet import DisGeNETProcessor
from utils.metric_learning_models import GDA_Metric_Learning
def parse_config():
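    """Parse command-line arguments for the fine-tuning pipeline."""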
parser = argparse.ArgumentParser()
    parser.add_argument('-f')  # swallows the -f kernel-file argument passed by Jupyter/IPython
parser.add_argument("--step", type=int, default=0)
parser.add_argument(
"--save_model_path",
type=str,
default=None,
help="path of the pretrained disease model located",
)
parser.add_argument(
"--prot_encoder_path",
type=str,
default="facebook/esm2_t33_650M_UR50D",
#"facebook/galactica-6.7b", "Rostlab/prot_bert" “facebook/esm2_t33_650M_UR50D”
help="path/name of protein encoder model located",
)
parser.add_argument(
"--disease_encoder_path",
type=str,
default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
help="path/name of textual pre-trained language model",
)
parser.add_argument("--reduction_factor", type=int, default=8)
parser.add_argument(
"--loss",
help="{ms_loss|infoNCE|cosine_loss|circle_loss|triplet_loss}}",
default="infoNCE",
)
parser.add_argument(
"--input_feature_save_path",
type=str,
default="../../data/processed_disease",
help="path of tokenized training data",
)
parser.add_argument(
"--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}"
)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--num_leaves", type=int, default=5)
parser.add_argument("--max_depth", type=int, default=5)
parser.add_argument("--lr", type=float, default=0.35)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--test", type=int, default=0)
parser.add_argument("--use_miner", action="store_true")
parser.add_argument("--miner_margin", default=0.2, type=float)
parser.add_argument("--freeze_prot_encoder", action="store_true")
parser.add_argument("--freeze_disease_encoder", action="store_true")
parser.add_argument("--use_adapter", action="store_true")
parser.add_argument("--use_pooled", action="store_true")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument(
"--use_both_feature",
help="use the both features of gnn_feature_v1_samples and pretrained models",
action="store_true",
)
parser.add_argument(
"--use_v1_feature_only",
help="use the features of gnn_feature_v1_samples only",
action="store_true",
)
parser.add_argument(
"--save_path_prefix",
type=str,
default="../../save_model_ckp/finetune/",
help="save the result in which directory",
)
parser.add_argument(
"--save_name", default="fine_tune", type=str, help="the name of the saved file"
)
# Add argument for input CSV file path
parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the input CSV file.")
# Add argument for output CSV file path
parser.add_argument("--output_csv_path", type=str, required=True, help="Path to the output CSV file.")
return parser.parse_args()
def get_feature(model, dataloader, args):
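    """Encode every batch in `dataloader` with `model.predict` and return the
    stacked feature matrix and label vector as NumPy arrays."""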
x = list()
y = list()
with torch.no_grad():
for step, batch in tqdm(enumerate(dataloader)):
prot_input_ids, prot_attention_mask, dis_input_ids, dis_attention_mask, y1 = batch
prot_input = {
'input_ids': prot_input_ids.to(args.device),
'attention_mask': prot_attention_mask.to(args.device)
}
dis_input = {
'input_ids': dis_input_ids.to(args.device),
'attention_mask': dis_attention_mask.to(args.device)
}
feature_output = model.predict(prot_input, dis_input)
x1 = feature_output.cpu().numpy()
x.append(x1)
y.append(y1.cpu().numpy())
x = np.concatenate(x, axis=0)
y = np.concatenate(y, axis=0)
return x, y
def encode_pretrained_feature(args, disGeNET):
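    """Encode train/valid/test gene-disease pairs into dense features.

    Train/valid features are cached as an .npz file under
    args.input_feature_save_path; the test split is always re-encoded.
    """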
    input_feat_file = os.path.join(
        args.input_feature_save_path,
        f"{args.model_short}_{args.step}_use_{'pooled' if args.use_pooled else 'cls'}_feat.npz",
    )

    # Load the tokenizers and backbone encoders once; both branches below need them.
    prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
    # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
    print("prot_tokenizer", len(prot_tokenizer))
    disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
    print("disease_tokenizer", len(disease_tokenizer))

    prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
    # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
    disease_model = BertModel.from_pretrained(args.disease_encoder_path)

    if args.save_model_path:
        model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
        if args.use_adapter:
            prot_model_path = os.path.join(
                args.save_model_path, f"prot_adapter_step_{args.step}"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"disease_adapter_step_{args.step}"
            )
            model.load_adapters(prot_model_path, disease_model_path)
        else:
            prot_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            model.non_adapters(prot_model_path, disease_model_path)
        model = model.to(args.device)
        prot_model = model.prot_encoder
        disease_model = model.disease_encoder
        print(f"loaded prior model {args.save_model_path}.")
    else:
        # get_feature() below calls model.predict(), which only exists on the
        # metric-learning model; fail early instead of hitting a NameError later.
        raise ValueError("--save_model_path is required to encode features.")

    def collate_fn_batch_encoding(batch):
        query1, query2, scores = zip(*batch)
        query_encodings1 = prot_tokenizer.batch_encode_plus(
            list(query1),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        query_encodings2 = disease_tokenizer.batch_encode_plus(
            list(query2),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        scores = torch.tensor(list(scores))
        attention_mask1 = query_encodings1["attention_mask"].bool()
        attention_mask2 = query_encodings2["attention_mask"].bool()
        return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores

    if os.path.exists(input_feat_file):
        # Train/valid features were cached on a previous run; only re-encode the test split.
        print(f"load prior feature data from {input_feat_file}.")
        loaded = np.load(input_feat_file)
        x_train, y_train = loaded["x_train"], loaded["y_train"]
        x_valid, y_valid = loaded["x_valid"], loaded["y_valid"]

        test_examples = disGeNET.get_test_examples(args.test)
        print(f"get test examples: {len(test_examples)}")
        test_dataloader = DataLoader(
            test_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        print(f"dataset loaded: test-{len(test_examples)}")
        x_test, y_test = get_feature(model, test_dataloader, args)
    else:
        train_examples = disGeNET.get_train_examples(args.test)
        print(f"get training examples: {len(train_examples)}")
        valid_examples = disGeNET.get_val_examples(args.test)
        print(f"get validation examples: {len(valid_examples)}")
        test_examples = disGeNET.get_test_examples(args.test)
        print(f"get test examples: {len(test_examples)}")

        train_dataloader = DataLoader(
            train_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        valid_dataloader = DataLoader(
            valid_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        test_dataloader = DataLoader(
            test_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        print(f"dataset loaded: train-{len(train_examples)}; valid-{len(valid_examples)}; test-{len(test_examples)}")

        x_train, y_train = get_feature(model, train_dataloader, args)
        x_valid, y_valid = get_feature(model, valid_dataloader, args)
        x_test, y_test = get_feature(model, test_dataloader, args)

        # Save train/valid features to skip the costly encoding step on later runs.
        os.makedirs(args.input_feature_save_path, exist_ok=True)
        np.savez_compressed(
            input_feat_file,
            x_train=x_train,
            y_train=y_train,
            x_valid=x_valid,
            y_valid=y_valid,
        )
        print(f"save input feature into {input_feat_file}")

    return x_train, y_train, x_valid, y_valid, x_test, y_test
def train(args):
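    """Encode features with the pretrained encoders, fit a LightGBM classifier,
    write the top-100 scored test pairs to args.output_csv_path, and compute
    accuracy, ROC-AUC, AUPR, Fmax, and F1 on the validation and test splits."""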
    # Derive a short model name used in the cached-feature filename.
    if args.save_model_path:
        args.model_short = args.save_model_path.split("/")[-1]
    else:
        args.model_short = args.disease_encoder_path.split("/")[-1]
    print(f"model name {args.model_short}")
# disGeNET = DisGeNETProcessor()
disGeNET = DisGeNETProcessor(input_csv_path=args.input_csv_path)
x_train, y_train, x_valid, y_valid, x_test, y_test = encode_pretrained_feature(args, disGeNET)
print("train: ", x_train.shape, y_train.shape)
print("valid: ", x_valid.shape, y_valid.shape)
print("test: ", x_test.shape, y_test.shape)
    params = {
        "task": "train",
        "boosting": "gbdt",  # one of "gbdt", "rf", "dart", "goss"; default is "gbdt"
        "objective": "binary",
        "num_leaves": args.num_leaves,
        "early_stopping_round": 30,
        "max_depth": args.max_depth,
        "learning_rate": args.lr,
        "metric": "binary_logloss",  # alternatives: "l2", "auc"
        "verbose": 1,
    }
    lgb_train = lgb.Dataset(x_train, y_train)
    lgb_valid = lgb.Dataset(x_valid, y_valid, reference=lgb_train)
    # fitting the model (early stopping monitors the validation set)
    model = lgb.train(params, train_set=lgb_train, valid_sets=[lgb_valid])
# prediction
valid_y_pred = model.predict(x_valid)
test_y_pred = model.predict(x_test)
    # rank test pairs (e.g. liver-fibrosis candidates) by predicted score and keep the top 100
predictions_df = pd.DataFrame(test_y_pred, columns=["Prediction_score"])
# data_test = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/test_tdc.csv')
data_test = pd.read_csv(args.input_csv_path)
predictions = pd.concat([data_test, predictions_df], axis=1)
    # e.g. keep a single disease: predictions[predictions['diseaseId'] == 'C0009714']
predictions.sort_values(by='Prediction_score', ascending=False, inplace=True)
top_100_predictions = predictions.head(100)
top_100_predictions.to_csv(args.output_csv_path, index=False)
# Accuracy
y_pred = model.predict(x_test, num_iteration=model.best_iteration)
y_pred[y_pred >= 0.5] = 1
y_pred[y_pred < 0.5] = 0
accuracy = accuracy_score(y_test, y_pred)
# AUC
valid_roc_auc_score = metrics.roc_auc_score(y_valid, valid_y_pred)
valid_average_precision_score = metrics.average_precision_score(
y_valid, valid_y_pred
)
test_roc_auc_score = metrics.roc_auc_score(y_test, test_y_pred)
test_average_precision_score = metrics.average_precision_score(y_test, test_y_pred)
# AUPR
valid_aupr = metrics.average_precision_score(y_valid, valid_y_pred)
test_aupr = metrics.average_precision_score(y_test, test_y_pred)
    # Fmax (a small epsilon keeps thresholds with precision = recall = 0 from yielding NaN)
    valid_precision, valid_recall, valid_thresholds = precision_recall_curve(y_valid, valid_y_pred)
    valid_fmax = (2 * valid_precision * valid_recall / (valid_precision + valid_recall + 1e-10)).max()
    test_precision, test_recall, test_thresholds = precision_recall_curve(y_test, test_y_pred)
    test_fmax = (2 * test_precision * test_recall / (test_precision + test_recall + 1e-10)).max()
# F1
valid_f1 = f1_score(y_valid, valid_y_pred >= 0.5)
test_f1 = f1_score(y_test, test_y_pred >= 0.5)
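    # Report the validation/test metrics computed above.
    print(f"test accuracy@0.5: {accuracy:.4f}")
    print(f"valid: ROC-AUC {valid_roc_auc_score:.4f}, AUPR {valid_aupr:.4f}, Fmax {valid_fmax:.4f}, F1 {valid_f1:.4f}")
    print(f"test:  ROC-AUC {test_roc_auc_score:.4f}, AUPR {test_aupr:.4f}, Fmax {test_fmax:.4f}, F1 {test_f1:.4f}")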
if __name__ == "__main__":
args = parse_config()
if torch.cuda.is_available():
print("cuda is available.")
print(f"current device {args}.")
else:
args.device = "cpu"
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
random_str = "".join([random.choice(string.ascii_lowercase) for n in range(6)])
best_model_dir = (
f"{args.save_path_prefix}{args.save_name}_{timestamp_str}_{random_str}/"
)
os.makedirs(best_model_dir)
args.save_name = best_model_dir
train(args)
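# Example invocation (paths are illustrative; --input_csv_path and --output_csv_path are required):
# python finetune.py \
#     --input_csv_path ../../data/downstream/GDA_Data/test_tdc.csv \
#     --output_csv_path ./top100_predictions.csv \
#     --save_model_path ../../save_model_ckp/pretrain_model \
#     --use_pooled --device cuda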