metadata
widget:
- text: >-
MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF
example_title: Protein Sequence 1
- text: >-
MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY
example_title: Protein Sequence 2
- text: >-
MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
example_title: Protein Sequence 3
license: mit
language:
- en
metrics:
- f1
- accuracy
- precision
- recall
- matthews_correlation
- roc_auc
library_name: peft
tags:
- ESM-2
- protein language model
- biology
- binding sites
Training:
For a report on the training please see here and here.
Metrics:
Train:
{'accuracy': 0.9406146072672105,
 'precision': 0.2947122459102886,
 'recall': 0.952624323712029,
 'f1': 0.4501592605994876,
 'auc': 0.9464622170085311,
 'mcc': 0.5118390407598565}
Test:
{'accuracy': 0.9266827008067329,
 'precision': 0.22378953253253775,
 'recall': 0.7790246675002842,
 'f1': 0.3476966444342296,
 'auc': 0.8547531675185658,
 'mcc': 0.3930283737012391}
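As a quick sanity check on the numbers above, F1 is the harmonic mean of precision and recall, so it can be recomputed from the two reported values. The sketch below does only that arithmetic; the helper name `f1_from` is illustrative and not part of the original training code.

```python
# Precision, recall, and F1 copied from the reported metrics above.
reported = {
    "train": {
        "precision": 0.2947122459102886,
        "recall": 0.952624323712029,
        "f1": 0.4501592605994876,
    },
    "test": {
        "precision": 0.22378953253253775,
        "recall": 0.7790246675002842,
        "f1": 0.3476966444342296,
    },
}


def f1_from(precision, recall):
    """Harmonic mean of precision and recall (illustrative helper)."""
    return 2 * precision * recall / (precision + recall)


for split, m in reported.items():
    recomputed = f1_from(m["precision"], m["recall"])
    # The recomputed and reported values should agree to many decimal places.
    print(split, round(recomputed, 6), round(m["f1"], 6))
```

The high recall combined with low precision means the model flags most true binding residues but also many non-binding ones, which is why F1 sits well below accuracy.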
Using the Model
Using on your Protein Sequences
To use the model on one of your own protein sequences, try running the following:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
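Printing one tuple per residue gets long for real proteins. A possible follow-up, sketched here with toy data, is to collect only the 1-based residue positions predicted as binding sites. The helper `binding_site_positions` is hypothetical and assumes the same special tokens and label ids (1 for "Binding site") used above.

```python
def binding_site_positions(tokens, predicted_ids,
                           special_tokens=("<pad>", "<cls>", "<eos>")):
    """Return 1-based residue positions whose predicted label id is 1.

    Hypothetical helper: special tokens are skipped and do not count
    toward the residue index.
    """
    positions = []
    residue_index = 0
    for token, label_id in zip(tokens, predicted_ids):
        if token in special_tokens:
            continue
        residue_index += 1
        if int(label_id) == 1:
            positions.append(residue_index)
    return positions


# Toy input standing in for the tokens and predictions of a real run
toy_tokens = ["<cls>", "M", "A", "V", "P", "<eos>", "<pad>"]
toy_predictions = [0, 0, 1, 1, 0, 0, 0]
print(binding_site_positions(toy_tokens, toy_predictions))  # prints [2, 3]
```

With the real outputs, something like `binding_site_positions(tokens, predictions[0].tolist())` should work the same way.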
Getting the Train/Test Metrics:
First, download the dataset from here. Once you have the pickle files locally, run the following:
from datasets import Dataset
from transformers import AutoTokenizer
import pickle
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]
# Set the maximum sequence length
max_sequence_length = 1000
# Load the data from pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)
# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)
# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
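Before building the datasets, it can be worth confirming that each sequence has exactly one label per residue. The sketch below assumes the label files hold one per-residue list for each sequence, which is an assumption about the pickle contents rather than something stated here. The helper `find_length_mismatches` is hypothetical.

```python
def find_length_mismatches(sequences, labels):
    """Return indices where a sequence and its label list differ in length.

    Hypothetical helper; assumes labels[i] is a per-residue list for sequences[i].
    """
    return [
        i
        for i, (seq, lab) in enumerate(zip(sequences, labels))
        if len(seq) != len(lab)
    ]


# Toy data standing in for the unpickled sequences and labels
toy_sequences = ["MAVP", "MNSV", "MEP"]
toy_labels = [[0, 1, 1, 0], [0, 0, 1], [1, 0, 0]]
print(find_length_mismatches(toy_sequences, toy_labels))  # prints [1]
```

An empty result on the raw (untruncated) data would suggest the sequences and labels line up as expected.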
Then run the following to get the train/test metrics:
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator
# Instantiate the accelerator
accelerator = Accelerator()
# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # Replace with the path to your own LoRA model if needed
# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model) # Prepare the model using the accelerator
# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)
# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)
    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()
    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary
# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)
train_metrics, test_metrics
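To make the masking and metric steps easier to follow, here is a dependency-free sketch of the same idea on toy data. It is not the code used to produce the reported numbers. The function `masked_binary_metrics` is hypothetical and assumes, as in `compute_metrics`, that positions to ignore carry the label -100. Because `compute_metrics` passes hard 0/1 predictions to `roc_auc_score`, the AUC reduces to the mean of recall and specificity, and the sketch computes it that way.

```python
import math


def masked_binary_metrics(labels, predictions, ignore_index=-100):
    """Binary metrics over positions whose label is not ignore_index."""
    # Drop ignored positions, mirroring `mask = labels != -100`
    pairs = [(y, p) for y, p in zip(labels, predictions) if y != ignore_index]
    if not pairs:
        raise ValueError("no labelled positions to score")

    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)

    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    # With hard 0/1 predictions, ROC AUC is the mean of recall and specificity
    auc = (recall + specificity) / 2
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall,
            "f1": f1, "auc": auc, "mcc": mcc}


# Toy labels with -100 at ignored positions, and toy hard predictions
toy_labels = [-100, 1, 0, 0, 1, 0, -100, -100]
toy_predictions = [0, 1, 1, 0, 0, 0, 1, 0]
print(masked_binary_metrics(toy_labels, toy_predictions))
```

In this toy case five positions are scored (one true positive, one false positive, one false negative, two true negatives), giving an accuracy of 0.6 and precision, recall, and F1 of 0.5 each.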