Soumic
:lady_beetle: Repaired some major mistakes, but the model returns accuracy = 50%
d06b274
import logging
import os
import time
from typing import Any
from huggingface_hub import PyTorchModelHubMixin
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
import torch
from torch import nn
from datasets import load_dataset
# --- Logging -----------------------------------------------------------------
# Configure the root logger once at import time; `timber` is the root logger
# used for every log line emitted by this module.
logging.basicConfig(level=logging.INFO)  # use level=logging.DEBUG for more verbose logs
timber = logging.getLogger()

# --- Regularization selector codes -------------------------------------------
# Bit-flag style codes: bit 0 selects L1, bit 1 selects L2.
NO_REGULARIZATION = 0
L1_REGULARIZATION_CODE = 1
L2_REGULARIZATION_CODE = 2
L1_AND_L2_REGULARIZATION_CODE = L1_REGULARIZATION_CODE | L2_REGULARIZATION_CODE  # == 3

# --- ANSI foreground colour escape sequences (prefixed to log messages) -------
black = "\u001b[30m"
red = "\u001b[31m"
green = "\u001b[32m"
yellow = "\u001b[33m"
blue = "\u001b[34m"
magenta = "\u001b[35m"
cyan = "\u001b[36m"
white = "\u001b[37m"

# --- Miscellaneous string constants --------------------------------------------
# FORWARD / BACKWARD are not referenced in the visible part of this file.
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"
# Hugging Face hub id used as the default for both the BERT model and tokenizer.
DNA_BERT_6 = "zhihan1996/DNA_bert_6"
class CommonAttentionLayer(nn.Module):
    """Attention pooling: score every position, softmax over dim 1, weighted sum.

    A single linear layer maps each hidden vector to one scalar score. Scores
    are normalised with a softmax along dimension 1 and used to reduce the
    input along that same dimension.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        """Create the scoring layer.

        Args:
            hidden_size: size of the last dimension of the tensors passed to
                ``forward``.
            *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.attention_linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        """Pool ``hidden_states`` along dimension 1.

        For an input of shape ``(batch, positions, hidden_size)`` this returns
        ``(context, weights)`` with shapes ``(batch, hidden_size)`` and
        ``(batch, positions, 1)``.
        """
        # One unnormalised score per position.
        scores = self.attention_linear(hidden_states)
        # Normalise the scores across dimension 1.
        attn_weights = scores.softmax(dim=1)
        # Weight every hidden vector, then collapse dimension 1.
        weighted_states = attn_weights * hidden_states
        return weighted_states.sum(dim=1), attn_weights
class DNABert6MqtlClassifier(nn.Module, PyTorchModelHubMixin):
    """BERT encoder + attention pooling over sequence chunks + linear head.

    Each sample is a stack of fixed-length token chunks. Every chunk is encoded
    by BERT and mean-pooled over its tokens; the per-chunk embeddings are then
    combined by ``CommonAttentionLayer`` and fed to a linear classifier.

    The output is a raw logit (no sigmoid is applied), which is what
    ``nn.BCEWithLogitsLoss`` expects.
    """

    def __init__(self,
                 bert_model=None,
                 hidden_size=768,
                 num_classes=1,
                 *args,
                 **kwargs):
        """Build the classifier.

        Args:
            bert_model: encoder whose output exposes ``last_hidden_state``.
                When ``None`` the pretrained ``DNA_BERT_6`` checkpoint is
                loaded here, at construction time, instead of once at import
                time as a shared default argument.
            hidden_size: width of the encoder's hidden states.
            num_classes: number of output logits.
            *args, **kwargs: forwarded unchanged to the parent initialiser.
        """
        super().__init__(*args, **kwargs)
        self.model_name = "DNABert6MqtlClassifier"
        if bert_model is None:
            bert_model = BertModel.from_pretrained(pretrained_model_name_or_path=DNA_BERT_6)
        self.bert_model = bert_model
        self.attention = CommonAttentionLayer(hidden_size)
        # The attention layer returns one hidden_size-wide vector per sample, so
        # the head's input width is hidden_size, independent of the chunk count.
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one logit (or ``num_classes`` logits) per sample.

        Args:
            input_ids: integer tensor indexed as ``[sample][chunk][token]``.
            attention_mask: same layout as ``input_ids`` (1 = real token,
                0 = padding), or ``None`` to average over every position.
            token_type_ids: same layout as ``input_ids``, or ``None``.

        Returns:
            Tensor of shape ``(batch,)`` when ``num_classes == 1``, otherwise
            ``(batch, num_classes)``.
        """
        per_sample_embeddings = []
        # Encode one sample (all of its chunks) at a time to bound peak memory.
        for i in range(input_ids.size(0)):
            sample_mask = attention_mask[i] if attention_mask is not None else None
            outputs = self.bert_model(
                input_ids=input_ids[i],
                attention_mask=sample_mask,
                token_type_ids=token_type_ids[i] if token_type_ids is not None else None,
            )
            hidden = outputs.last_hidden_state  # (num_chunks, seq_len, hidden)
            if sample_mask is None:
                pooled = hidden.mean(dim=1)
            else:
                # Masked mean so that padding positions do not dilute the embedding.
                weights = sample_mask.unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
            per_sample_embeddings.append(pooled)  # (num_chunks, hidden)

        # Keep the batch axis separate from the feature axis. Concatenating along
        # dim=1 would fold samples into features and leave the attention softmax
        # operating on a size-1 axis.
        chunk_embeddings = torch.stack(per_sample_embeddings, dim=0)  # (batch, num_chunks, hidden)

        # Attention pooling over the chunk axis -> (batch, hidden).
        context_vector, _ = self.attention(chunk_embeddings)

        logits = self.classifier(context_vector)  # (batch, num_classes)
        if logits.size(-1) == 1:
            # Match the (batch,) shape of the label tensor for BCE-style losses.
            logits = logits.squeeze(-1)
        return logits
class TorchMetrics:
    """Five torcheval binary metrics that are updated, reported and reset together."""

    def __init__(self):
        self.binary_accuracy = BinaryAccuracy()
        self.binary_auc = BinaryAUROC()
        self.binary_f1_score = BinaryF1Score()
        self.binary_precision = BinaryPrecision()
        self.binary_recall = BinaryRecall()

    def _all_metrics(self):
        """Return the metric objects in their fixed reporting order."""
        return (
            self.binary_accuracy,
            self.binary_auc,
            self.binary_f1_score,
            self.binary_precision,
            self.binary_recall,
        )

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
        """Feed one batch of predictions and targets to every metric."""
        # torcheval names the prediction argument `input` (it used to be `preds`).
        for metric in self._all_metrics():
            metric.update(input=batch_predicted_labels, target=batch_actual_labels)

    def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
        """Compute every metric, write one coloured summary line, and call ``log`` per metric.

        Args:
            log: callable invoked as ``log(name, value)`` for each metric.
            log_prefix: prefix for every metric name (e.g. ``"train"``).
            log_color: string prepended to the summary line.
        """
        b_accuracy, b_auc, b_f1_score, b_precision, b_recall = [
            metric.compute() for metric in self._all_metrics()
        ]
        timber.info(
            log_color + f"{log_prefix}_acc = {b_accuracy}, {log_prefix}_auc = {b_auc}, {log_prefix}_f1_score = {b_f1_score}, {log_prefix}_precision = {b_precision}, {log_prefix}_recall = {b_recall}")
        named_values = (
            (f"{log_prefix}_accuracy", b_accuracy),
            (f"{log_prefix}_auc", b_auc),
            (f"{log_prefix}_f1_score", b_f1_score),
            (f"{log_prefix}_precision", b_precision),
            (f"{log_prefix}_recall", b_recall),
        )
        for metric_name, metric_value in named_values:
            log(metric_name, metric_value)

    def reset_on_epoch_end(self):
        """Clear the accumulated state of every metric."""
        for metric in self._all_metrics():
            metric.reset()
class MQtlBertClassifierLightningModule(LightningModule):
    """Lightning wrapper that trains/evaluates a binary classifier producing logits.

    The wrapped ``classifier`` is called with ``(input_ids, attention_mask,
    token_type_ids)``. Its output is treated as a logit: it is passed unchanged
    to ``criterion`` and converted to a hard 0/1 prediction via a sigmoid and a
    0.5 probability threshold for the metrics.
    """

    def __init__(self,
                 classifier: nn.Module,
                 criterion=nn.BCEWithLogitsLoss(),
                 regularization: int = L2_REGULARIZATION_CODE,
                 # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
                 l1_lambda=0.0001,
                 l2_wright_decay=0.0001,
                 *args: Any,
                 **kwargs: Any):
        """Store the model, loss, metric trackers and regularization settings.

        Args:
            classifier: the model being trained.
            criterion: loss called as ``criterion(logits, targets.float())``.
            regularization: one of the ``*_REGULARIZATION_CODE`` constants.
            l1_lambda: multiplier of the L1 penalty added to the training loss.
            l2_wright_decay: optimizer weight decay used when L2 is selected
                (the parameter name is kept as-is for existing callers).
            *args, **kwargs: forwarded unchanged to ``LightningModule``.
        """
        super().__init__(*args, **kwargs)
        self.classifier = classifier
        self.criterion = criterion
        self.train_metrics = TorchMetrics()
        self.validate_metrics = TorchMetrics()
        self.test_metrics = TorchMetrics()
        self.regularization = regularization
        self.l1_lambda = l1_lambda
        self.l2_weight_decay = l2_wright_decay

    def forward(self, input_ids, attention_mask, token_type_ids, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped classifier and return its raw output."""
        # Call the module (not .forward) so that registered hooks are honoured.
        return self.classifier(input_ids, attention_mask, token_type_ids)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return Adam (lr=1e-5); weight decay is enabled only when L2 is selected."""
        uses_l2 = self.regularization in (L2_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE)
        weight_decay = self.l2_weight_decay if uses_l2 else 0.0
        return torch.optim.Adam(self.parameters(), lr=1e-5, weight_decay=weight_decay)

    def _shared_step(self, batch, metrics, loss_name, log_prefix, log_color, apply_l1):
        """Forward pass, loss, and metric bookkeeping shared by train/validate/test.

        Args:
            batch: ``(input_ids, attention_mask, token_type_ids, y)``.
            metrics: the ``TorchMetrics`` tracker for the current stage.
            loss_name: key under which the loss is logged.
            log_prefix: prefix for the metric names.
            log_color: colour prefix for the metric summary line.
            apply_l1: whether the L1 penalty may be added (training only).

        Returns:
            The scalar loss tensor.
        """
        input_ids, attention_mask, token_type_ids, y = batch
        # Flatten so a (batch, 1) model output lines up with (batch,) labels.
        logits = self.forward(input_ids, attention_mask, token_type_ids).reshape(-1)
        y = y.reshape(-1)

        # The criterion consumes logits, so the 0.5 threshold must be applied to
        # sigmoid(logits), not to the logits themselves.
        predicted_class = (torch.sigmoid(logits) >= 0.5).int()

        loss = self.criterion(logits, y.float())
        if apply_l1 and self.regularization in (L1_REGULARIZATION_CODE, L1_AND_L2_REGULARIZATION_CODE):
            l1_norm = sum(p.abs().sum() for p in self.parameters())
            loss = loss + self.l1_lambda * l1_norm
        self.log(loss_name, loss)

        metrics.update_on_each_step(batch_predicted_labels=predicted_class, batch_actual_labels=y)
        metrics.compute_metrics_and_log(log=self.log, log_prefix=log_prefix, log_color=log_color)
        return loss

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute and log the training loss (plus optional L1 penalty) and metrics."""
        return self._shared_step(batch, self.train_metrics, loss_name="train_loss",
                                 log_prefix="train", log_color=green, apply_l1=True)

    def on_train_epoch_end(self) -> None:
        """Log the epoch-level training metrics and reset the tracker."""
        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
        self.train_metrics.reset_on_epoch_end()

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute and log the validation loss and metrics."""
        return self._shared_step(batch, self.validate_metrics, loss_name="valid_loss",
                                 log_prefix="validate", log_color=blue, apply_l1=False)

    def on_validation_epoch_end(self) -> None:
        """Log the epoch-level validation metrics and reset the tracker."""
        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
        self.validate_metrics.reset_on_epoch_end()
        return None

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Compute and log the test loss and metrics."""
        return self._shared_step(batch, self.test_metrics, loss_name="test_loss",
                                 log_prefix="test", log_color=magenta, apply_l1=False)

    def on_test_epoch_end(self) -> None:
        """Log the epoch-level test metrics and reset the tracker."""
        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
        self.test_metrics.reset_on_epoch_end()
        return None
class PagingMQTLDnaBertDataset(IterableDataset):
    """Streams rows of ``{'sequence', 'label'}`` as chunked, tokenized tensors.

    Each sequence is cut into consecutive chunks of ``max_length`` characters.
    Every chunk is rewritten as overlapping, space-separated k-mers before it
    is handed to the tokenizer, because a k-mer vocabulary (such as DNABERT's)
    cannot match one long unbroken string of nucleotides.
    """

    def __init__(self, dataset, tokenizer, max_length=512, kmer_size=6):
        """Store the row source and tokenization settings.

        Args:
            dataset: iterable of mappings with ``'sequence'`` and ``'label'``.
            tokenizer: callable accepting ``(text, truncation=, padding=,
                max_length=, return_tensors=)`` and returning a mapping with
                ``'input_ids'``, ``'attention_mask'`` and optionally
                ``'token_type_ids'``.
            max_length: chunk length in characters and padded token length.
            kmer_size: length of the overlapping k-mers (6 for DNABERT-6).

        Raises:
            ValueError: if ``kmer_size`` is smaller than 1.
        """
        if kmer_size < 1:
            raise ValueError(f"kmer_size must be >= 1, got {kmer_size}")
        self.dataset = dataset
        self.bert_tokenizer = tokenizer
        self.max_length = max_length
        self.kmer_size = kmer_size

    def __iter__(self):
        """Yield one processed sample per usable row; rows without a sequence are skipped."""
        for row in self.dataset:
            processed = self.preprocess(row)
            if processed is not None:
                yield processed

    def _to_kmer_text(self, chunk):
        """Return ``chunk`` as space-separated overlapping k-mers.

        A chunk shorter than ``kmer_size`` has no k-mers and is returned as-is.
        """
        k = self.kmer_size
        kmers = [chunk[i:i + k] for i in range(len(chunk) - k + 1)]
        return " ".join(kmers) if kmers else chunk

    def preprocess(self, row):
        """Tokenize one row.

        Returns:
            ``(input_ids, attention_mask, token_type_ids, label)`` where the
            first three are ``(num_chunks, max_length)`` tensors and ``label``
            is an int32 scalar tensor, or ``None`` when the row has no
            sequence.
        """
        sequence = row['sequence']
        label = row['label']
        if not sequence:
            # Nothing to tokenize; __iter__ drops None results.
            return None
        # Upper-case so soft-masked (lower-case) bases still match an upper-case vocabulary.
        sequence = sequence.upper()

        # Split the sequence into chunks of max_length characters. With k=6 a
        # 512-character chunk yields 507 k-mers, which fits in 512 tokens
        # together with the two special tokens.
        chunks = [sequence[i:i + self.max_length] for i in range(0, len(sequence), self.max_length)]

        input_id_rows, attention_rows, token_type_rows = [], [], []
        for chunk in chunks:
            encoded_chunk = self.bert_tokenizer(
                self._to_kmer_text(chunk),
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            chunk_ids = encoded_chunk['input_ids'].squeeze(0)
            input_id_rows.append(chunk_ids)
            attention_rows.append(encoded_chunk['attention_mask'].squeeze(0))
            if 'token_type_ids' in encoded_chunk:
                token_type_rows.append(encoded_chunk['token_type_ids'].squeeze(0))
            else:
                # Single-segment input: all-zero segment ids keep the tuple collatable.
                token_type_rows.append(torch.zeros_like(chunk_ids))

        input_ids = torch.stack(input_id_rows)
        attention_mask = torch.stack(attention_rows)
        token_type_ids = torch.stack(token_type_rows)
        return input_ids, attention_mask, token_type_ids, torch.tensor(label).int()
class DNABERTDataModule(LightningDataModule):
    """Lightning data module serving the train/validate/test splits.

    The raw splits come either from local CSV files (``is_local=True``) or from
    the ``fahimfarhan/mqtl-classification-datasets`` hub dataset, and are
    wrapped in ``PagingMQTLDnaBertDataset`` for tokenization.
    """

    def __init__(self, model_name=DNA_BERT_6, batch_size=8, WINDOW=-1, is_local=False):
        """Load the tokenizer and record the loader settings.

        Args:
            model_name: tokenizer checkpoint passed to ``BertTokenizer.from_pretrained``.
            batch_size: batch size for all three data loaders.
            WINDOW: stored as ``self.window``; not otherwise used in this class.
            is_local: read the local CSV files instead of the hub dataset.
        """
        super().__init__()
        self.tokenized_dataset = None
        self.dataset = None
        self.train_dataset: PagingMQTLDnaBertDataset = None
        self.validate_dataset: PagingMQTLDnaBertDataset = None
        self.test_dataset: PagingMQTLDnaBertDataset = None
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
        self.batch_size = batch_size
        self.is_local = is_local
        self.window = WINDOW

    def prepare_data(self):
        """Load the raw splits into ``self.dataset``."""
        data_files = {
            # small samples
            "train_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_train_binned.csv",
            "validate_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_validate_binned.csv",
            "test_binned_200": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_200_test_binned.csv",
            # medium samples
            "train_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_train_binned.csv",
            "validate_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_validate_binned.csv",
            "test_binned_1000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_1000_test_binned.csv",
            # large samples
            "train_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv",
            "validate_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_validate_binned.csv",
            "test_binned_4000": "/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_test_binned.csv",
            # NOTE: the "tiny_*" keys currently point at the medium_dataset_4000 files.
            "tiny_train": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_train_binned.csv",
            "tiny_validate": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_validate_binned.csv",
            "tiny_test": "/home/soumic/Codes/mqtl-classification/src/inputdata/medium_dataset_4000_test_binned.csv",
        }
        if self.is_local:
            self.dataset = load_dataset("csv", data_files=data_files, streaming=True)
        else:
            self.dataset = load_dataset("fahimfarhan/mqtl-classification-datasets")

    def setup(self, stage=None):
        """Wrap the raw splits in tokenizing datasets."""
        if self.dataset is None:
            # prepare_data() is not guaranteed to have run in this process, so
            # make sure the raw splits exist before indexing into them.
            self.prepare_data()
        # Train on the training split; using the test split here would leak the
        # evaluation data into training.
        self.train_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_train'], self.tokenizer)
        self.validate_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_validate'], self.tokenizer)
        self.test_dataset = PagingMQTLDnaBertDataset(self.dataset['tiny_test'], self.tokenizer)

    def train_dataloader(self):
        """Return the training ``DataLoader`` (one worker)."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=1)

    def val_dataloader(self):
        """Return the validation ``DataLoader`` (one worker)."""
        return DataLoader(self.validate_dataset, batch_size=self.batch_size, num_workers=1)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return the test ``DataLoader`` (one worker)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=1)
def start_bert(classifier_model, model_save_path, criterion, WINDOW, batch_size=4,
               is_binned=True, is_debug=False, max_epochs=10, regularization_code=L2_REGULARIZATION_CODE):
    """Train, test, save and publish ``classifier_model``.

    Steps, in order: build the data module and Lightning module, fit, test,
    save the Lightning module's state dict to ``model_save_path``, save the
    classifier with ``save_pretrained`` and push it to the Hugging Face hub.

    Args:
        classifier_model: model to train; must provide ``save_pretrained`` and
            ``push_to_hub``.
        model_save_path: destination file for ``torch.save`` of the state dict.
        criterion: loss passed to the Lightning module.
        WINDOW: window size; used in the local directory name, the hub
            repository id and the commit message, and passed to the data module.
        batch_size: batch size for the data loaders.
        is_binned: accepted for compatibility; has no effect on the run.
        is_debug: accepted for compatibility; has no effect on the run.
        max_epochs: number of training epochs.
        regularization_code: one of the ``*_REGULARIZATION_CODE`` constants.
    """
    # The presence of this file decides between local CSV input and the hub dataset.
    is_my_laptop = os.path.isfile("/home/soumic/Codes/mqtl-classification/src/inputdata/dataset_4000_train_binned.csv")

    data_module = DNABERTDataModule(batch_size=batch_size, WINDOW=WINDOW, is_local=is_my_laptop)
    classifier_module = MQtlBertClassifierLightningModule(
        classifier=classifier_model,
        regularization=regularization_code,
        criterion=criterion,
    )

    # Load and wrap the data before handing the module to the trainer.
    data_module.prepare_data()
    data_module.setup()

    trainer = Trainer(max_epochs=max_epochs, precision="32")
    trainer.fit(model=classifier_module, datamodule=data_module)
    trainer.test(model=classifier_module, datamodule=data_module)

    # Persist the weights twice: raw state dict, then hub-style directory.
    torch.save(classifier_module.state_dict(), model_save_path)
    classifier_model.save_pretrained(save_directory=f"my-awesome-model-{WINDOW}", safe_serialization=False)

    # The commit message records where the push originated.
    commit_message = (
        f":tada: Push model for window size {WINDOW} from zephyrus"
        if is_my_laptop
        else f":tada: Push model for window size {WINDOW} from huggingface space"
    )
    classifier_model.push_to_hub(
        repo_id=f"fahimfarhan/dnabert-6-mqtl-classifier-{WINDOW}",
        commit_message=commit_message,
    )
def _main():
    """Train the DNABERT-6 classifier for window size 4000 and print the wall-clock runtime."""
    start_time = time.time()

    pytorch_model = DNABert6MqtlClassifier()
    start_bert(classifier_model=pytorch_model, model_save_path=f"weights_{pytorch_model.model_name}.pth",
               criterion=nn.BCEWithLogitsLoss(), WINDOW=4000, batch_size=1,  # 12, # max 14 on my laptop...
               max_epochs=1, regularization_code=L2_REGULARIZATION_CODE)

    # Elapsed time from start to the end of training, testing and upload.
    duration = time.time() - start_time
    print(f"Runtime: {duration:.2f} seconds")


if __name__ == "__main__":
    _main()