---
license: mit
---

## Metrics

```python
Train:
({'accuracy': 0.9406146072672105,
  'precision': 0.2947122459102886,
  'recall': 0.952624323712029,
  'f1': 0.4501592605994876,
  'auc': 0.9464622170085311,
  'mcc': 0.5118390407598565},
Test:
 {'accuracy': 0.9266827008067329,
  'precision': 0.22378953253253775,
  'recall': 0.7790246675002842,
  'f1': 0.3476966444342296,
  'auc': 0.8547531675185658,
  'mcc': 0.3930283737012391})
```

## Using the Model

First, download the dataset from [AmelieSchreiber/binding_sites_random_split_by_family](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) on Hugging Face.
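If you prefer to fetch the files programmatically, a minimal sketch using `huggingface_hub` is below; the four filenames match the loading code further down, and `local_dir="."` places them in the current directory:

```python
from huggingface_hub import hf_hub_download

repo_id = "AmelieSchreiber/binding_sites_random_split_by_family"
filenames = [
    "train_sequences_chunked_by_family.pkl",
    "test_sequences_chunked_by_family.pkl",
    "train_labels_chunked_by_family.pkl",
    "test_labels_chunked_by_family.pkl",
]

# Download each pickle file from the dataset repo into the current directory
for filename in filenames:
    hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", local_dir=".")
```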
Once you have the pickle files downloaded locally, run the following:
```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from the pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```
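Note that the padded `input_ids` (which also include special tokens) are generally longer than the truncated label lists; the `DataCollatorForTokenClassification` used in the next step pads each label list with `-100`, and the evaluation code masks those positions out. As an optional sanity check that the datasets were built as expected (a sketch, assuming the code above ran in the same session):

```python
# Optional sanity check: dataset sizes and per-example field lengths.
# input_ids are padded and include special tokens, so they are generally
# longer than the truncated label lists; the data collator reconciles this.
print(len(train_dataset), len(test_dataset))
example = train_dataset[0]
print(len(example["input_ids"]), len(example["labels"]))
```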

Then run the following to get the train/test metrics:
```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)
from peft import PeftModel
from transformers import (
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer
)
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the base and LoRA models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # replace with the path to your LoRA model if different

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # Prepare the model using the accelerator

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator that pads labels with -100
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens (labeled -100 by the collator)
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Matthews correlation coefficient

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}

# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

train_metrics, test_metrics
```
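To run the model on a single sequence and read off per-residue predictions, something along these lines should work (a sketch: the sequence below is an arbitrary placeholder, and `tokenizer`, `model`, `id2label`, and `max_sequence_length` come from the code above):

```python
import torch

# Arbitrary placeholder sequence; substitute your own protein sequence
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=max_sequence_length)
inputs = {k: v.to(model.device) for k, v in inputs.items()}  # match the model's device

model.eval()
with torch.no_grad():
    logits = model(**inputs).logits

# One prediction per token; drop the special tokens at both ends
predictions = torch.argmax(logits, dim=2)[0].tolist()[1:-1]
for residue, pred in zip(sequence, predictions):
    print(residue, id2label[pred])
```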