import argparse
import os
import random
import string
import sys
from datetime import datetime

import pandas as pd

sys.path.append("../")
import numpy as np
import torch
import lightgbm as lgb
import sklearn.metrics as metrics
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BertModel, BertTokenizer, EsmModel, EsmTokenizer

from utils.downstream_disgenet import DisGeNETProcessor
from utils.metric_learning_models import GDA_Metric_Learning


def parse_config():
    parser = argparse.ArgumentParser()
    # Absorbs the "-f <connection file>" flag that Jupyter/IPython passes.
    parser.add_argument("-f")
    parser.add_argument("--step", type=int, default=0)
    parser.add_argument(
        "--save_model_path",
        type=str,
        default=None,
        help="path where the pretrained model checkpoint is located",
    )
    parser.add_argument(
        "--prot_encoder_path",
        type=str,
        default="facebook/esm2_t33_650M_UR50D",
        # alternatives: "facebook/galactica-6.7b", "Rostlab/prot_bert"
        help="path/name of the protein encoder model",
    )
    parser.add_argument(
        "--disease_encoder_path",
        type=str,
        default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        help="path/name of the textual pre-trained language model",
    )
    parser.add_argument("--reduction_factor", type=int, default=8)
    parser.add_argument(
        "--loss",
        help="{ms_loss|infoNCE|cosine_loss|circle_loss|triplet_loss}",
        default="infoNCE",
    )
    parser.add_argument(
        "--input_feature_save_path",
        type=str,
        default="../../data/processed_disease",
        help="path of tokenized training data",
    )
    parser.add_argument(
        "--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}"
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--num_leaves", type=int, default=5)
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.35)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--test", type=int, default=0)
    parser.add_argument("--use_miner", action="store_true")
    parser.add_argument("--miner_margin", default=0.2, type=float)
    parser.add_argument("--freeze_prot_encoder", action="store_true")
    parser.add_argument("--freeze_disease_encoder", action="store_true")
    parser.add_argument("--use_adapter", action="store_true")
    parser.add_argument("--use_pooled", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument(
        "--use_both_feature",
        help="use both the gnn_feature_v1_samples features and the pretrained-model features",
        action="store_true",
    )
    parser.add_argument(
        "--use_v1_feature_only",
        help="use the gnn_feature_v1_samples features only",
        action="store_true",
    )
    parser.add_argument(
        "--save_path_prefix",
        type=str,
        default="../../save_model_ckp/finetune/",
        help="directory in which the results are saved",
    )
    parser.add_argument(
        "--save_name", default="fine_tune", type=str, help="name of the saved file"
    )
    parser.add_argument(
        "--input_csv_path", type=str, required=True, help="Path to the input CSV file."
    )
    parser.add_argument(
        "--output_csv_path", type=str, required=True, help="Path to the output CSV file."
    )
    return parser.parse_args()

def get_feature(model, dataloader, args):
    x = list()
    y = list()
    with torch.no_grad():
        for step, batch in tqdm(enumerate(dataloader)):
            prot_input_ids, prot_attention_mask, dis_input_ids, dis_attention_mask, y1 = batch
            prot_input = {
                "input_ids": prot_input_ids.to(args.device),
                "attention_mask": prot_attention_mask.to(args.device),
            }
            dis_input = {
                "input_ids": dis_input_ids.to(args.device),
                "attention_mask": dis_attention_mask.to(args.device),
            }
            feature_output = model.predict(prot_input, dis_input)
            x1 = feature_output.cpu().numpy()
            x.append(x1)
            y.append(y1.cpu().numpy())
    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    return x, y
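
# Note: get_feature() expects each batch to be the 5-tuple produced by
# collate_fn_batch_encoding (defined in encode_pretrained_feature below),
# and model.predict() to return one fused gene-disease feature vector per
# pair; the stacked (n_samples, feature_dim) matrix is what the LightGBM
# model in train() consumes.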

def encode_pretrained_feature(args, disGeNET):
    input_feat_file = os.path.join(
        args.input_feature_save_path,
        f"{args.model_short}_{args.step}_use_{'pooled' if args.use_pooled else 'cls'}_feat.npz",
    )

    # The tokenizer/encoder setup is identical whether or not cached features
    # exist, so it is built once here instead of in both branches.
    prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
    print("prot_tokenizer", len(prot_tokenizer))
    disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
    print("disease_tokenizer", len(disease_tokenizer))

    prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
    disease_model = BertModel.from_pretrained(args.disease_encoder_path)

    # Feature extraction below calls model.predict(), so --save_model_path
    # must point to a pretrained metric-learning checkpoint.
    if args.save_model_path:
        model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
        if args.use_adapter:
            prot_model_path = os.path.join(
                args.save_model_path, f"prot_adapter_step_{args.step}"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"disease_adapter_step_{args.step}"
            )
            model.load_adapters(prot_model_path, disease_model_path)
        else:
            prot_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            model.non_adapters(prot_model_path, disease_model_path)
        model = model.to(args.device)
        prot_model = model.prot_encoder
        disease_model = model.disease_encoder
        print(f"loaded prior model {args.save_model_path}.")

    def collate_fn_batch_encoding(batch):
        query1, query2, scores = zip(*batch)
        query_encodings1 = prot_tokenizer.batch_encode_plus(
            list(query1),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        query_encodings2 = disease_tokenizer.batch_encode_plus(
            list(query2),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        scores = torch.tensor(list(scores))
        attention_mask1 = query_encodings1["attention_mask"].bool()
        attention_mask2 = query_encodings2["attention_mask"].bool()
        return (
            query_encodings1["input_ids"],
            attention_mask1,
            query_encodings2["input_ids"],
            attention_mask2,
            scores,
        )

    if os.path.exists(input_feat_file):
        # Train/valid features come from the cache; the test split is always
        # re-encoded from the current input CSV.
        print(f"load prior feature data from {input_feat_file}.")
        loaded = np.load(input_feat_file)
        x_train, y_train = loaded["x_train"], loaded["y_train"]
        x_valid, y_valid = loaded["x_valid"], loaded["y_valid"]

        test_examples = disGeNET.get_test_examples(args.test)
        print(f"get test examples: {len(test_examples)}")
        test_dataloader = DataLoader(
            test_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        print(f"dataset loaded: test-{len(test_examples)}")
        x_test, y_test = get_feature(model, test_dataloader, args)
    else:
        train_examples = disGeNET.get_train_examples(args.test)
        print(f"get training examples: {len(train_examples)}")
        valid_examples = disGeNET.get_val_examples(args.test)
        print(f"get validation examples: {len(valid_examples)}")
        test_examples = disGeNET.get_test_examples(args.test)
        print(f"get test examples: {len(test_examples)}")

        train_dataloader = DataLoader(
            train_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        valid_dataloader = DataLoader(
            valid_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        test_dataloader = DataLoader(
            test_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        print(
            f"dataset loaded: train-{len(train_examples)}; "
            f"valid-{len(valid_examples)}; test-{len(test_examples)}"
        )

        x_train, y_train = get_feature(model, train_dataloader, args)
        x_valid, y_valid = get_feature(model, valid_dataloader, args)
        x_test, y_test = get_feature(model, test_dataloader, args)

        # Cache the train/valid features to skip encoding next time.
        np.savez_compressed(
            input_feat_file,
            x_train=x_train,
            y_train=y_train,
            x_valid=x_valid,
            y_valid=y_valid,
        )
        print(f"save input feature into {input_feat_file}")

    return x_train, y_train, x_valid, y_valid, x_test, y_test
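
# Each example yielded by the DisGeNETProcessor splits is consumed by
# collate_fn_batch_encoding as a (protein_sequence, disease_text, label)
# triple, e.g. (illustrative values only, not taken from the dataset):
#     ("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "liver cirrhosis", 1.0)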

def train(args):
    # Derive a short model name used in the cached-feature filename.
    if args.save_model_path:
        args.model_short = args.save_model_path.split("/")[-1]
    else:
        args.model_short = args.disease_encoder_path.split("/")[-1]
    print(f"model name {args.model_short}")

    disGeNET = DisGeNETProcessor(input_csv_path=args.input_csv_path)

    x_train, y_train, x_valid, y_valid, x_test, y_test = encode_pretrained_feature(
        args, disGeNET
    )

    print("train: ", x_train.shape, y_train.shape)
    print("valid: ", x_valid.shape, y_valid.shape)
    print("test: ", x_test.shape, y_test.shape)

    params = {
        "task": "train",
        # Boosting options: "gbdt" (traditional Gradient Boosting Decision
        # Tree, the default), "rf" (Random Forest), "dart" (Dropouts meet
        # Multiple Additive Regression Trees), "goss" (Gradient-based
        # One-Side Sampling).
        "boosting": "gbdt",
        "objective": "binary",
        "num_leaves": args.num_leaves,
        "early_stopping_round": 30,
        "max_depth": args.max_depth,
        "learning_rate": args.lr,
        "metric": "binary_logloss",  # alternatives: "l2", "auc"
        "verbose": 1,
    }

    lgb_train = lgb.Dataset(x_train, y_train)
    lgb_valid = lgb.Dataset(x_valid, y_valid, reference=lgb_train)

    # Fit the model; early stopping monitors the validation set.
    model = lgb.train(params, train_set=lgb_train, valid_sets=[lgb_valid])

    # Predict association scores.
    valid_y_pred = model.predict(x_valid)
    test_y_pred = model.predict(x_test)

    # Rank the test pairs by predicted score and export the top 100.
    predictions_df = pd.DataFrame(test_y_pred, columns=["Prediction_score"])
    data_test = pd.read_csv(args.input_csv_path)
    predictions = pd.concat([data_test, predictions_df], axis=1)
    # e.g. restrict to a single disease:
    # predictions = predictions[predictions["diseaseId"] == "C0009714"]
    predictions.sort_values(by="Prediction_score", ascending=False, inplace=True)
    top_100_predictions = predictions.head(100)
    top_100_predictions.to_csv(args.output_csv_path, index=False)

    # Accuracy at a 0.5 decision threshold
    y_pred = model.predict(x_test, num_iteration=model.best_iteration)
    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0
    accuracy = accuracy_score(y_test, y_pred)

    # ROC-AUC and AUPR (average precision)
    valid_roc_auc_score = metrics.roc_auc_score(y_valid, valid_y_pred)
    valid_aupr = metrics.average_precision_score(y_valid, valid_y_pred)
    test_roc_auc_score = metrics.roc_auc_score(y_test, test_y_pred)
    test_aupr = metrics.average_precision_score(y_test, test_y_pred)

    # Fmax: the best F1 over all decision thresholds
    valid_precision, valid_recall, valid_thresholds = precision_recall_curve(
        y_valid, valid_y_pred
    )
    valid_fmax = (
        2 * valid_precision * valid_recall / (valid_precision + valid_recall)
    ).max()
    test_precision, test_recall, test_thresholds = precision_recall_curve(
        y_test, test_y_pred
    )
    test_fmax = (
        2 * test_precision * test_recall / (test_precision + test_recall)
    ).max()

    # F1 at the 0.5 threshold
    valid_f1 = f1_score(y_valid, valid_y_pred >= 0.5)
    test_f1 = f1_score(y_test, test_y_pred >= 0.5)

    print(
        f"valid: ROC-AUC={valid_roc_auc_score:.4f}, AUPR={valid_aupr:.4f}, "
        f"Fmax={valid_fmax:.4f}, F1={valid_f1:.4f}"
    )
    print(
        f"test:  ROC-AUC={test_roc_auc_score:.4f}, AUPR={test_aupr:.4f}, "
        f"Fmax={test_fmax:.4f}, F1={test_f1:.4f}, accuracy={accuracy:.4f}"
    )


if __name__ == "__main__":
    args = parse_config()
    if torch.cuda.is_available():
        print("cuda is available.")
        print(f"current device: {args.device}.")
    else:
        args.device = "cpu"
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    random_str = "".join([random.choice(string.ascii_lowercase) for n in range(6)])
    best_model_dir = (
        f"{args.save_path_prefix}{args.save_name}_{timestamp_str}_{random_str}/"
    )
    os.makedirs(best_model_dir)
    args.save_name = best_model_dir
    train(args)
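
# Example invocation (script name and all paths are hypothetical
# placeholders; only the flags defined in parse_config above are used):
#   python finetune_gda.py \
#       --save_model_path ../../save_model_ckp/pretrain_ckp \
#       --input_csv_path ../../data/downstream/GDA_Data/test.csv \
#       --output_csv_path ../../results/top_100_predictions.csv \
#       --device cuda --batch_size 256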