Gla-AI4BioMed-Lab committed
Commit 90fac5b
1 Parent(s): c54bfd4

Delete src/finetune/.ipynb_checkpoints

src/finetune/.ipynb_checkpoints/finetune-checkpoint.py DELETED
@@ -1,416 +0,0 @@
import argparse
import os
import random
import string
import sys
import pandas as pd
from datetime import datetime

sys.path.append("../")
import numpy as np
import torch
import lightgbm as lgb
import sklearn.metrics as metrics
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_curve, f1_score, precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import EsmTokenizer, EsmForMaskedLM, BertModel, BertTokenizer, AutoTokenizer, EsmModel
from utils.downstream_disgenet import DisGeNETProcessor
from utils.metric_learning_models import GDA_Metric_Learning

def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('-f')  # placeholder so the parser tolerates the -f argument Jupyter injects
    parser.add_argument("--step", type=int, default=0)
    parser.add_argument(
        "--save_model_path",
        type=str,
        default=None,
        help="path where the pretrained model checkpoint is located",
    )
    parser.add_argument(
        "--prot_encoder_path",
        type=str,
        default="facebook/esm2_t33_650M_UR50D",
        # alternatives: "facebook/galactica-6.7b", "Rostlab/prot_bert"
        help="path or name of the protein encoder model",
    )
    parser.add_argument(
        "--disease_encoder_path",
        type=str,
        default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        help="path or name of the pretrained language model for disease text",
    )
    parser.add_argument("--reduction_factor", type=int, default=8)
    parser.add_argument(
        "--loss",
        help="{ms_loss|infoNCE|cosine_loss|circle_loss|triplet_loss}",
        default="infoNCE",
    )
    parser.add_argument(
        "--input_feature_save_path",
        type=str,
        default="../../data/processed_disease",
        help="path of the tokenized training data",
    )
    parser.add_argument(
        "--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}"
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--num_leaves", type=int, default=5)
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.35)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--test", type=int, default=0)
    parser.add_argument("--use_miner", action="store_true")
    parser.add_argument("--miner_margin", default=0.2, type=float)
    parser.add_argument("--freeze_prot_encoder", action="store_true")
    parser.add_argument("--freeze_disease_encoder", action="store_true")
    parser.add_argument("--use_adapter", action="store_true")
    parser.add_argument("--use_pooled", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument(
        "--use_both_feature",
        help="use both the gnn_feature_v1_samples features and the pretrained-model features",
        action="store_true",
    )
    parser.add_argument(
        "--use_v1_feature_only",
        help="use the gnn_feature_v1_samples features only",
        action="store_true",
    )
    parser.add_argument(
        "--save_path_prefix",
        type=str,
        default="../../save_model_ckp/finetune/",
        help="directory in which to save results",
    )
    parser.add_argument(
        "--save_name", default="fine_tune", type=str, help="the name of the saved file"
    )
    parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the input CSV file.")
    parser.add_argument("--output_csv_path", type=str, required=True, help="Path to the output CSV file.")
    return parser.parse_args()
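# A minimal sketch of how this script might be invoked; the paths and flag
# values below are illustrative assumptions, not taken from the repository:
#
#   python finetune.py \
#       --input_csv_path ../../data/gda_pairs.csv \
#       --output_csv_path ../../results/top100_predictions.csv \
#       --save_model_path ../../save_model_ckp/pretrain_ckp \
#       --use_adapter --use_pooled \
#       --device cuda --batch_size 64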
def get_feature(model, dataloader, args):
    """Encode every protein-disease batch with the metric-learning model and stack the results."""
    x = list()
    y = list()
    model.eval()  # not in the original script: keeps dropout inactive while encoding
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader)):
            prot_input_ids, prot_attention_mask, dis_input_ids, dis_attention_mask, y1 = batch
            prot_input = {
                'input_ids': prot_input_ids.to(args.device),
                'attention_mask': prot_attention_mask.to(args.device)
            }
            dis_input = {
                'input_ids': dis_input_ids.to(args.device),
                'attention_mask': dis_attention_mask.to(args.device)
            }
            feature_output = model.predict(prot_input, dis_input)
            x1 = feature_output.cpu().numpy()
            x.append(x1)
            y.append(y1.cpu().numpy())
    x = np.concatenate(x, axis=0)
    y = np.concatenate(y, axis=0)
    return x, y


def encode_pretrained_feature(args, disGeNET):
    """Encode the splits with the pretrained encoders, caching train/valid features on disk."""
    input_feat_file = os.path.join(
        args.input_feature_save_path,
        f"{args.model_short}_{args.step}_use_{'pooled' if args.use_pooled else 'cls'}_feat.npz",
    )

    # The tokenizers, encoders, and test features are needed whether or not the
    # cache exists: the cache stores only the train/valid features, so the test
    # split is always encoded fresh. (The original duplicated this whole setup
    # in both branches; it is hoisted here without changing behavior.)
    prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
    # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
    print("prot_tokenizer", len(prot_tokenizer))
    disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
    print("disease_tokenizer", len(disease_tokenizer))

    prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
    # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
    disease_model = BertModel.from_pretrained(args.disease_encoder_path)

    if args.save_model_path:
        model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)

        if args.use_adapter:
            prot_model_path = os.path.join(
                args.save_model_path, f"prot_adapter_step_{args.step}"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"disease_adapter_step_{args.step}"
            )
            model.load_adapters(prot_model_path, disease_model_path)
        else:
            prot_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            disease_model_path = os.path.join(
                args.save_model_path, f"step_{args.step}_model.bin"
            )
            model.non_adapters(prot_model_path, disease_model_path)

        model = model.to(args.device)
        prot_model = model.prot_encoder
        disease_model = model.disease_encoder
        print(f"loaded prior model {args.save_model_path}.")
    else:
        # The original only defined `model` when --save_model_path was given and
        # would otherwise crash later with a NameError; fail loudly instead.
        raise ValueError("--save_model_path is required to encode features.")

    def collate_fn_batch_encoding(batch):
        query1, query2, scores = zip(*batch)

        query_encodings1 = prot_tokenizer.batch_encode_plus(
            list(query1),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        query_encodings2 = disease_tokenizer.batch_encode_plus(
            list(query2),
            max_length=512,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        scores = torch.tensor(list(scores))
        attention_mask1 = query_encodings1["attention_mask"].bool()
        attention_mask2 = query_encodings2["attention_mask"].bool()

        return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores

    test_examples = disGeNET.get_test_examples(args.test)
    print(f"get test examples: {len(test_examples)}")
    test_dataloader = DataLoader(
        test_examples,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn_batch_encoding,
    )

    if os.path.exists(input_feat_file):
        # Cache hit: reuse the stored train/valid features.
        print(f"load prior feature data from {input_feat_file}.")
        loaded = np.load(input_feat_file)
        x_train, y_train = loaded["x_train"], loaded["y_train"]
        x_valid, y_valid = loaded["x_valid"], loaded["y_valid"]
        print(f"dataset loaded: test-{len(test_examples)}")
    else:
        # Cache miss: encode train/valid from scratch, then store them.
        train_examples = disGeNET.get_train_examples(args.test)
        print(f"get training examples: {len(train_examples)}")
        valid_examples = disGeNET.get_val_examples(args.test)
        print(f"get validation examples: {len(valid_examples)}")

        train_dataloader = DataLoader(
            train_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        valid_dataloader = DataLoader(
            valid_examples,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_fn_batch_encoding,
        )
        print(f"dataset loaded: train-{len(train_examples)}; valid-{len(valid_examples)}; test-{len(test_examples)}")

        x_train, y_train = get_feature(model, train_dataloader, args)
        x_valid, y_valid = get_feature(model, valid_dataloader, args)

        # Save the train/valid features to reduce encoding time on later runs.
        np.savez_compressed(
            input_feat_file,
            x_train=x_train,
            y_train=y_train,
            x_valid=x_valid,
            y_valid=y_valid,
        )
        print(f"save input feature into {input_feat_file}")

    x_test, y_test = get_feature(model, test_dataloader, args)
    return x_train, y_train, x_valid, y_valid, x_test, y_test


def train(args):
    """Encode features with the pretrained encoders, then fit a LightGBM binary classifier."""
    if args.save_model_path:
        args.model_short = args.save_model_path.split("/")[-1]
    else:
        args.model_short = args.disease_encoder_path.split("/")[-1]
    print(f"model name {args.model_short}")

    disGeNET = DisGeNETProcessor(input_csv_path=args.input_csv_path)

    x_train, y_train, x_valid, y_valid, x_test, y_test = encode_pretrained_feature(args, disGeNET)

    print("train: ", x_train.shape, y_train.shape)
    print("valid: ", x_valid.shape, y_valid.shape)
    print("test: ", x_test.shape, y_test.shape)

    params = {
        "task": "train",  # or "predict"
        "boosting": "gbdt",  # options: "gbdt", "rf", "dart", "goss"; default is "gbdt"
        "objective": "binary",
        "num_leaves": args.num_leaves,
        "early_stopping_round": 30,
        "max_depth": args.max_depth,
        "learning_rate": args.lr,
        "metric": "binary_logloss",  # alternatives: "l2", "auc"
        "verbose": 1,
    }

    lgb_train = lgb.Dataset(x_train, y_train)
    lgb_valid = lgb.Dataset(x_valid, y_valid)
    lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)  # note: constructed but unused below

    # Fit the model, with early stopping monitored on the validation set.
    model = lgb.train(params, train_set=lgb_train, valid_sets=lgb_valid)
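    # Optional, not in the original script: persist the trained booster so it
    # can be reloaded without re-training (the filename is an assumption).
    # model.save_model(os.path.join(args.save_name, "lgb_model.txt"))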
    # Prediction
    valid_y_pred = model.predict(x_valid)
    test_y_pred = model.predict(x_test)

    # Rank the test pairs by predicted association score (used here to screen
    # liver-fibrosis candidates) and export the 100 highest-scoring rows.
    predictions_df = pd.DataFrame(test_y_pred, columns=["Prediction_score"])
    data_test = pd.read_csv(args.input_csv_path)
    predictions = pd.concat([data_test, predictions_df], axis=1)
    predictions.sort_values(by='Prediction_score', ascending=False, inplace=True)
    top_100_predictions = predictions.head(100)
    top_100_predictions.to_csv(args.output_csv_path, index=False)

    # Accuracy at a 0.5 decision threshold
    y_pred = model.predict(x_test, num_iteration=model.best_iteration)
    y_pred[y_pred >= 0.5] = 1
    y_pred[y_pred < 0.5] = 0
    accuracy = accuracy_score(y_test, y_pred)

    # ROC-AUC and AUPR (average precision)
    valid_roc_auc_score = metrics.roc_auc_score(y_valid, valid_y_pred)
    valid_aupr = metrics.average_precision_score(y_valid, valid_y_pred)
    test_roc_auc_score = metrics.roc_auc_score(y_test, test_y_pred)
    test_aupr = metrics.average_precision_score(y_test, test_y_pred)

    # Fmax: the best F1 score attainable over all decision thresholds
    valid_precision, valid_recall, valid_thresholds = precision_recall_curve(y_valid, valid_y_pred)
    valid_fmax = (2 * valid_precision * valid_recall / (valid_precision + valid_recall)).max()
    test_precision, test_recall, test_thresholds = precision_recall_curve(y_test, test_y_pred)
    test_fmax = (2 * test_precision * test_recall / (test_precision + test_recall)).max()

    # F1 at a 0.5 threshold
    valid_f1 = f1_score(y_valid, valid_y_pred >= 0.5)
    test_f1 = f1_score(y_test, test_y_pred >= 0.5)

    # The original script computed these metrics without reporting them; print a summary.
    print(f"test accuracy@0.5: {accuracy:.4f}")
    print(f"valid: ROC-AUC {valid_roc_auc_score:.4f}, AUPR {valid_aupr:.4f}, Fmax {valid_fmax:.4f}, F1 {valid_f1:.4f}")
    print(f"test:  ROC-AUC {test_roc_auc_score:.4f}, AUPR {test_aupr:.4f}, Fmax {test_fmax:.4f}, F1 {test_f1:.4f}")


if __name__ == "__main__":
    args = parse_config()
    if torch.cuda.is_available():
        print("cuda is available.")
        print(f"current device: {args.device}.")
    else:
        args.device = "cpu"  # fall back to CPU when CUDA is unavailable
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    random_str = "".join(random.choice(string.ascii_lowercase) for _ in range(6))
    best_model_dir = (
        f"{args.save_path_prefix}{args.save_name}_{timestamp_str}_{random_str}/"
    )
    os.makedirs(best_model_dir)
    args.save_name = best_model_dir
    train(args)
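
A typical follow-up to a deletion like this (not part of this commit) is a .gitignore entry that keeps Jupyter from re-introducing checkpoint copies:

    .ipynb_checkpoints/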