import logging import os import time from typing import Any from huggingface_hub import PyTorchModelHubMixin from pytorch_lightning import Trainer, LightningModule, LightningDataModule from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS from torch.utils.data import DataLoader, Dataset, IterableDataset from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions import torch from torch import nn from datasets import load_dataset timber = logging.getLogger() # logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.INFO) # change to level=logging.DEBUG to print more logs... NO_REGULARIZATION = 0 L1_REGULARIZATION_CODE = 1 L2_REGULARIZATION_CODE = 2 L1_AND_L2_REGULARIZATION_CODE = 3 black = "\u001b[30m" red = "\u001b[31m" green = "\u001b[32m" yellow = "\u001b[33m" blue = "\u001b[34m" magenta = "\u001b[35m" cyan = "\u001b[36m" white = "\u001b[37m" FORWARD = "FORWARD_INPUT" BACKWARD = "BACKWARD_INPUT" DNA_BERT_6 = "zhihan1996/DNA_bert_6" class CommonAttentionLayer(nn.Module): def __init__(self, hidden_size, *args, **kwargs): super().__init__(*args, **kwargs) self.attention_linear = nn.Linear(hidden_size, 1) pass def forward(self, hidden_states): # Apply linear layer attn_weights = self.attention_linear(hidden_states) # Apply softmax to get attention scores attn_weights = torch.softmax(attn_weights, dim=1) # Apply attention weights to hidden states context_vector = torch.sum(attn_weights * hidden_states, dim=1) return context_vector, attn_weights class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss): def forward(self, input, target): return super().forward(input.squeeze(), target.float()) class MQtlDnaBERT6Classifier(nn.Module, PyTorchModelHubMixin): def __init__(self, bert_model=BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6), hidden_size=768, num_classes=1, *args, **kwargs ): super().__init__(*args, **kwargs) self.model_name = "MQtlDnaBERT6Classifier" self.bert_model = bert_model self.attention = CommonAttentionLayer(hidden_size) self.classifier = nn.Linear(hidden_size, num_classes) pass def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids): """ # torch.Size([128, 1, 512]) --> [128, 512] input_ids = input_ids.squeeze(dim=1).to(DEVICE) # torch.Size([16, 1, 512]) --> [16, 512] attention_mask = attention_mask.squeeze(dim=1).to(DEVICE) token_type_ids = token_type_ids.squeeze(dim=1).to(DEVICE) """ bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids ) last_hidden_state = bert_output.last_hidden_state context_vector, ignore_attention_weight = self.attention(last_hidden_state) y = self.classifier(context_vector) return y """ class TorchMetrics: def __init__(self): self.binary_accuracy = BinaryAccuracy() #.to(device) self.binary_auc = BinaryAUROC() # .to(device) self.binary_f1_score = BinaryF1Score() # .to(device) self.binary_precision = BinaryPrecision() # .to(device) self.binary_recall = BinaryRecall() # .to(device) pass def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed # it looks like the library maintainers changed preds to input, ie, before: preds, now: input self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels) pass def compute_and_log_on_each_step(self, log, log_prefix: str, log_color: str = green): b_accuracy = self.binary_accuracy.compute() b_auc = self.binary_auc.compute() b_f1_score = self.binary_f1_score.compute() b_precision = self.binary_precision.compute() b_recall = self.binary_recall.compute() timber.info(log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}") log(f"{log_prefix}_accuracy", b_accuracy) log(f"{log_prefix}_auc", b_auc) log(f"{log_prefix}_f1_score", b_f1_score) log(f"{log_prefix}_precision", b_precision) log(f"{log_prefix}_recall", b_recall) # def reset_on_epoch_end(self): # self.binary_accuracy.reset() # self.binary_auc.reset() # self.binary_f1_score.reset() # self.binary_precision.reset() # self.binary_recall.reset() def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green): b_accuracy = self.binary_accuracy.compute() b_auc = self.binary_auc.compute() b_f1_score = self.binary_f1_score.compute() b_precision = self.binary_precision.compute() b_recall = self.binary_recall.compute() timber.info( log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}") log(f"{log_prefix}_accuracy", b_accuracy) log(f"{log_prefix}_auc", b_auc) log(f"{log_prefix}_f1_score", b_f1_score) log(f"{log_prefix}_precision", b_precision) log(f"{log_prefix}_recall", b_recall) self.binary_accuracy.reset() self.binary_auc.reset() self.binary_f1_score.reset() self.binary_precision.reset() self.binary_recall.reset() pass """ class TorchMetrics: def __init__(self): self.binary_accuracy = BinaryAccuracy() #.to(device) self.binary_auc = BinaryAUROC() # .to(device) self.binary_f1_score = BinaryF1Score() # .to(device) self.binary_precision = BinaryPrecision() # .to(device) self.binary_recall = BinaryRecall() # .to(device) pass def update_on_each_step(self, batch_predicted_labels, batch_actual_labels): # todo: Add log if needed # it looks like the library maintainers changed preds to input, ie, before: preds, now: input self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels) self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels) pass def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green): b_accuracy = self.binary_accuracy.compute() b_auc = self.binary_auc.compute() b_f1_score = self.binary_f1_score.compute() b_precision = self.binary_precision.compute() b_recall = self.binary_recall.compute() timber.info( log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}") log(f"{log_prefix}_accuracy", b_accuracy) log(f"{log_prefix}_auc", b_auc) log(f"{log_prefix}_f1_score", b_f1_score) log(f"{log_prefix}_precision", b_precision) log(f"{log_prefix}_recall", b_recall) pass def reset_on_epoch_end(self): self.binary_accuracy.reset() self.binary_auc.reset() self.binary_f1_score.reset() self.binary_precision.reset() self.binary_recall.reset() class MQtlBertClassifierLightningModule(LightningModule): def __init__(self, classifier: nn.Module, criterion=None, # nn.BCEWithLogitsLoss(), regularization: int = L2_REGULARIZATION_CODE, # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care l1_lambda=0.0001, l2_wright_decay=0.0001, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.classifier = classifier self.criterion = criterion self.train_metrics = TorchMetrics() self.validate_metrics = TorchMetrics() self.test_metrics = TorchMetrics() self.regularization = regularization self.l1_lambda = l1_lambda self.l2_weight_decay = l2_wright_decay pass def forward(self, x, *args: Any, **kwargs: Any) -> Any: input_ids: torch.tensor = x["input_ids"] attention_mask: torch.tensor = x["attention_mask"] token_type_ids: torch.tensor = x["token_type_ids"] # print(f"\n{ type(input_ids) = }, {input_ids = }") # print(f"{ type(attention_mask) = }, { attention_mask = }") # print(f"{ type(token_type_ids) = }, { token_type_ids = }") return self.classifier.forward(input_ids, attention_mask, token_type_ids) def configure_optimizers(self) -> OptimizerLRScheduler: # Here we add weight decay (L2 regularization) to the optimizer weight_decay = 0.0 if self.regularization == 2 or self.regularization == 3: weight_decay = self.l2_weight_decay return torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=weight_decay) # , weight_decay=0.005) def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT: # Accuracy on training batch data x, y = batch preds = self.forward(x) loss = self.criterion(preds, y) if self.regularization == 1 or self.regularization == 3: # apply l1 regularization l1_norm = sum(p.abs().sum() for p in self.parameters()) loss += self.l1_lambda * l1_norm self.log("train_loss", loss) # calculate the scores start self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y) self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train") # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train") # calculate the scores end return loss def on_train_epoch_end(self) -> None: self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train") self.train_metrics.reset_on_epoch_end() pass def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT: # Accuracy on validation batch data # print(f"debug { batch = }") x, y = batch preds = self.forward(x) loss = self.criterion(preds, y) """ loss = 0 # <------------------------- maybe the loss calculation is problematic """ self.log("valid_loss", loss) # calculate the scores start self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y) self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue) # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue) # calculate the scores end return loss def on_validation_epoch_end(self) -> None: self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue) self.validate_metrics.reset_on_epoch_end() return None def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT: # Accuracy on validation batch data x, y = batch preds = self.forward(x) loss = self.criterion(preds, y) self.log("test_loss", loss) # do we need this? # calculate the scores start self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y) self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta) # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta) # calculate the scores end return loss def on_test_epoch_end(self) -> None: self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta) self.test_metrics.reset_on_epoch_end() return None pass class PagingMQTLDnaBertDataset(IterableDataset): def __init__(self, dataset, tokenizer, max_length=512): # hold on! why is it 512? I added 4000, and it crashed, the error suggested 512, that's why 512 self.dataset = dataset self.bert_tokenizer = tokenizer self.max_length = max_length # def __len__(self): # return len(self.dataset) def __iter__(self): for row in self.dataset: processed = self.preprocess(row) if processed is not None: yield processed def preprocess(self, row): sequence = row['sequence'] # Fetch the 'sequence' column label = row['label'] # Fetch the 'label' column (or whatever target you use) # Tokenize the sequence encoded_sequence: BatchEncoding = self.bert_tokenizer( sequence, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt' ) encoded_sequence_squeezed = {key: val.squeeze() for key, val in encoded_sequence.items()} return encoded_sequence_squeezed, label class DNABERTDataModule(LightningDataModule): def __init__(self, model_name=DNA_BERT_6, batch_size=8, WINDOW=-1, is_local=False): super().__init__() self.tokenized_dataset = None self.dataset = None self.train_dataset: PagingMQTLDnaBertDataset = None self.validate_dataset: PagingMQTLDnaBertDataset = None self.test_dataset: PagingMQTLDnaBertDataset = None self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name) self.batch_size = batch_size self.is_local = is_local self.window = WINDOW def prepare_data(self): # Download and prepare dataset data_files = { # small samples "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv", "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv", "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv", # medium samples "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv", "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv", "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv", # large samples "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv", "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv", "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv", # really tiny # "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_train_binned.csv", # "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_validate_binned.csv", # "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/tiny_dataset_4000_test_binned.csv", "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_train_binned.csv", "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_validate_binned.csv", "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_test_binned.csv", } if self.is_local: self.dataset = load_dataset("csv", data_files=data_files, streaming=True) else: self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets") def setup(self, stage=None): self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_train'], self.tokenizer) self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_validate'], self.tokenizer) self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer) def train_dataloader(self): return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1) def val_dataloader(self): return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=1) def test_dataloader(self) -> EVAL_DATALOADERS: return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=1) def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4, is_binned=True, is_debug=False, max_epochs=10, regularization_code = L2_REGULARIZATION_CODE): is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv") model_local_directory = f"my-awesome-model-{WINDOW}" model_remote_repository = f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}" file_suffix = "" if is_binned: file_suffix = "_binned" data_module = DNABERTDataModule(batch_size=batch_size, WINDOW=WINDOW, is_local=is_my_laptop) # classifier_model = classifier_model.to(DEVICE) classifier_module = MQtlBertClassifierLightningModule( classifier=classifier_model, regularization=regularization_code, criterion=criterion) # if os.path.exists(model_save_path): # classifier_module.load_state_dict(torch.load(model_save_path)) classifier_module = classifier_module # .double() # Prepare data using the DataModule data_module.prepare_data() data_module.setup() trainer = Trainer(max_epochs=max_epochs, precision="32") # Train the model trainer.fit(model=classifier_module, datamodule=data_module) trainer.test(model=classifier_module, datamodule=data_module) torch.save(classifier_module.state_dict(), model_save_path) # classifier_module.push_to_hub("fahimfarhan/mqtl-classifier-model") classifier_model.save_pretrained(save_directory=model_local_directory, safe_serialization=False) # push to the hub commit_message = f":tada: Push model for window size {WINDOW} from huggingface space" if is_my_laptop: commit_message = f":tada: Push model for window size {WINDOW} from zephyrus" classifier_model.push_to_hub( repo_id=model_remote_repository, # subfolder=f"my-awesome-model-{WINDOW}", subfolder didn't work :/ commit_message=commit_message, # f":tada: Push model for window size {WINDOW}" # safe_serialization=False ) pass if __name__ == "__main__": start_time = time.time() dataset_folder_prefix = "inputdata/" pytorch_model = MQtlDnaBERT6Classifier() start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth", criterion=ReshapedBCEWithLogitsLoss(), WINDOW=4000, batch_size=12, # max 14 on my laptop... max_epochs=1, regularization_code=L2_REGULARIZATION_CODE) # Record the end time end_time = time.time() # Calculate the duration duration = end_time - start_time # Print the runtime print(f"Runtime: {duration:.2f} seconds") pass