ESM-2 for Binding Site Prediction
This model is overfit (see below). It is a finetuned version of the 35M parameter esm2_t12_35M_UR50D
(see here
and here for more details). The model was finetuned with LoRA for
the binary token classification task of predicting binding sites (and active sites) of protein sequences from sequence alone.
The model may also be undertrained. However, on the test set it still achieved better loss, accuracy,
precision, recall, F1 score, ROC AUC, and Matthews Correlation Coefficient (MCC) than the models trained on the smaller
dataset of ~209K protein sequences found here. Note that
this model has high recall, meaning it is likely to detect binding sites, but low precision, meaning it will also return
many false positives.
Training procedure
This model was finetuned on ~549K protein sequences from the UniProt database. The dataset can be found here. The model obtains the following train and test metrics:
Train: {'accuracy': 0.9905461579981686,
'precision': 0.7695765003685506,
'recall': 0.9841352974610041,
'f1': 0.8637307441810476,
'auc': 0.9874413786006525,
'mcc': 0.8658850560635515}
Test: {'accuracy': 0.9394282959813123,
'precision': 0.3662722265170941,
'recall': 0.8330231316088238,
'f1': 0.5088208423175958,
'auc': 0.8883078682492643,
'mcc': 0.5283098562376193}
To analyze the train and test metrics, we first look at each metric individually and then give an overall view of the model's performance.
1. Accuracy
- Train: 99.05%
- Test: 93.94%
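Accuracy is high on both the training and test sets. However, binding-site residues make up only a small fraction of all residues. As a result, accuracy is dominated by the many correctly labeled non-binding residues and says little about how well the model finds binding sites.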
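The reported test metrics can show how much of the accuracy comes from the majority class. This relies on one assumption that the card does not state: that precision, recall, and accuracy are all computed from one pooled, token-level confusion matrix. Under that assumption, the three numbers determine the fraction of test residues that are binding sites. That fraction in turn gives the accuracy of a trivial model that never predicts a binding site. A minimal sketch of that arithmetic:

```python
# Reported test metrics from this card.
precision = 0.3662722265170941
recall = 0.8330231316088238
accuracy = 0.9394282959813123


def positive_fraction(precision: float, recall: float, accuracy: float) -> float:
    """Fraction p of tokens that are true binding sites.

    Assumes all three metrics come from one pooled token-level confusion
    matrix. Expressing every cell as a fraction of all tokens:
      TP = recall * p
      FN = (1 - recall) * p
      FP = TP * (1 / precision - 1)
    and 1 - accuracy = FN + FP, which is linear in p.
    """
    errors_per_positive = (1 - recall) + recall * (1 / precision - 1)
    return (1 - accuracy) / errors_per_positive


p = positive_fraction(precision, recall, accuracy)
# A model that always predicts "No binding site" is right on every negative token.
majority_baseline = 1 - p

print(f"binding-site fraction:    {p:.4f}")
print(f"always-negative accuracy: {majority_baseline:.4f}")
print(f"model accuracy:           {accuracy:.4f}")
```

Under this assumption, roughly 3.8% of test residues are binding sites. Always predicting "No binding site" would therefore score about 96.2% accuracy, which is above the model's 93.94%. This is why precision, recall, F1, and MCC are more informative for this task.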
2. Precision
- Train: 76.96%
- Test: 36.63%
Precision measures the proportion of positive predictions that are true positives, and it drops sharply on the test set. Only about 37% of the residues the model labels as binding sites actually are binding sites. Roughly 63% of its positive predictions on unseen data are therefore false positives.
3. Recall
- Train: 98.41%
- Test: 83.30%
Recall is the proportion of actual positives that are correctly identified. It remains high on the test set, although lower than on the training set, so the model is sensitive and finds most binding sites.
4. F1-Score
- Train: 86.37%
- Test: 50.88%
The F1-score is the harmonic mean of precision and recall. The significant drop in the F1-score from training to testing indicates that the balance between precision and recall has worsened in the test set, which is primarily due to the lower precision.
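A short check of the harmonic-mean definition helps here. Recomputing F1 from the card's reported precision and recall reproduces the reported F1 values. The recomputation also shows that the harmonic mean is pulled toward the smaller of the two numbers. On the test set, F1 sits far below the arithmetic mean of precision and recall.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Reported precision and recall from this card.
train_f1 = f1_score(0.7695765003685506, 0.9841352974610041)
test_f1 = f1_score(0.3662722265170941, 0.8330231316088238)
# For contrast: the arithmetic mean is not dragged down by the weaker metric.
test_arithmetic_mean = (0.3662722265170941 + 0.8330231316088238) / 2

print(f"train F1: {train_f1:.4f}")
print(f"test F1:  {test_f1:.4f}")
print(f"test arithmetic mean of P and R: {test_arithmetic_mean:.4f}")
```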
5. AUC (Area Under the ROC Curve)
- Train: 98.74%
- Test: 88.83%
AUC is high on both sets but lower on the test set. A high AUC indicates good separability: the model tends to score binding-site tokens above non-binding tokens. Because AUC is computed across all decision thresholds, it is not tied to the single cut-off behind the precision and recall figures.
6. MCC (Matthews Correlation Coefficient)
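- Train: 0.8659
- Test: 0.5283
MCC is a balanced metric that uses all four cells of the confusion matrix (true and false positives and negatives). It ranges from -1 to +1, where 0 means no better than chance, so it stays informative when one class is rare. The decline in MCC from training to testing indicates a drop in the quality of the binary classifications.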
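A consistency check on the reported test metrics is possible here. It assumes, as the card does not state explicitly, that precision, recall, and accuracy come from one pooled token-level confusion matrix. Under that assumption, those three metrics can be turned back into a normalized confusion matrix. The MCC computed from that matrix lands close to the reported 0.5283. MCC is unchanged when all four cells are scaled by the same factor, so fractions of tokens work as well as raw counts.

```python
import math

# Reported test metrics from this card.
precision = 0.3662722265170941
recall = 0.8330231316088238
accuracy = 0.9394282959813123

# Fraction of binding-site tokens implied by the three metrics, assuming they
# all come from one pooled token-level confusion matrix:
# 1 - accuracy = FN + FP = p * ((1 - recall) + recall * (1/precision - 1)).
p = (1 - accuracy) / ((1 - recall) + recall * (1 / precision - 1))

# Confusion-matrix cells as fractions of all test tokens.
tp = recall * p
fn = (1 - recall) * p
fp = tp * (1 / precision - 1)
tn = 1 - tp - fn - fp


def mcc(tp: float, tn: float, fp: float, fn: float) -> float:
    """Matthews correlation coefficient in [-1, 1]; 0 means chance level."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom > 0 else 0.0


test_mcc = mcc(tp, tn, fp, fn)
print(f"MCC from reconstructed matrix: {test_mcc:.4f}")
```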
Overall Analysis
Overfitting: The significant drop in metrics such as precision, F1-score, and MCC from training to test set suggests that the model might be overfitting to the training data, i.e., it may not generalize well to unseen data.
High Recall, Low Precision: The model has a high recall but low precision on the test set, indicating that it is identifying too many cases as positive, including those that are actually negative (false positives). This could be a reflection of a model that is biased towards predicting the positive class.
Improvement Suggestions:
- Data Augmentation: Consider data augmentation strategies to make the model more robust.
- Class Weights: If there is a class imbalance in the dataset, adjusting the class weights during training might help.
- Hyperparameter Tuning: Experiment with different hyperparameters, including the learning rate and batch size, and check whether they improve performance on held-out data.
- Feature Engineering: Consider revisiting the features used to train the model. Sometimes, introducing new features or removing irrelevant ones can help improve performance.
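For the class-weight suggestion, one common approach in PyTorch is to pass per-class weights to `nn.CrossEntropyLoss` and flatten the token dimension. The sketch below is not the training code used for this model. The weight values are hypothetical. Because this model already over-predicts binding sites, the useful direction may be to lower rather than raise the weight on the positive class, so the values should be tuned on validation data. The `-100` ignore index follows the Hugging Face convention for masking padding and special tokens in token-classification labels.

```python
import torch
import torch.nn as nn

# Hypothetical per-class weights: index 0 = "No binding site", 1 = "Binding site".
# Illustrative values only; they are not the settings used to train this model.
class_weights = torch.tensor([1.0, 2.0])

# ignore_index=-100 drops padding/special tokens from the loss entirely.
loss_fn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)


def weighted_token_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """logits: (batch, seq_len, num_labels); labels: (batch, seq_len), -100 = ignore."""
    # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten tokens.
    return loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))


# Toy batch: one sequence of 4 tokens; the first and last are special tokens.
logits = torch.zeros(1, 4, 2)
labels = torch.tensor([[-100, 0, 1, -100]])
loss = weighted_token_loss(logits, labels)
print(loss.item())
```

With all-zero logits, every counted token has a loss of ln 2. The weighted mean (PyTorch divides by the summed weights of the non-ignored targets) is therefore also ln 2, about 0.6931, whatever weights are chosen.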
In conclusion, while the model performs excellently on the training set, its performance drops in the test set, suggesting that there is room for improvement to make the model more generalizable to unseen data. It would be beneficial to look into strategies to reduce overfitting and improve precision without significantly sacrificing recall.
Increasing the dataset size from ~209K to ~549K protein sequences clearly improved the test metrics. We used Hugging Face's Parameter-Efficient Fine-Tuning (PEFT) library to finetune with Low-Rank Adaptation (LoRA). We chose a LoRA rank of 2, as this slightly improved the test metrics compared to ranks 8 and 16 on the same model trained on the smaller dataset.
Framework versions
- PEFT 0.5.0
Using the model
To run the model on one of your protein sequences, try the following:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model and apply the LoRA adapter on top of it
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted label for each residue, skipping special and padding tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
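The usage example above labels each token with `torch.argmax`. With two classes, this amounts to a 0.5 cut-off on the softmax probability of "Binding site". Given the high recall and low precision reported for this model, raising that cut-off is one way to trade some recall for precision. The function name and the 0.8 threshold below are hypothetical. A threshold should be chosen on held-out data, not taken from this sketch.

```python
import torch


def binding_site_predictions(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Label a token 1 ("Binding site") when P(label 1) >= threshold, else 0.

    logits: (batch, seq_len, 2), as returned by the token-classification model.
    With two classes, threshold=0.5 matches argmax except for exact ties.
    """
    probs = torch.softmax(logits, dim=-1)[..., 1]  # probability of "Binding site"
    return (probs >= threshold).long()


# Toy logits for three tokens: clearly negative, borderline, clearly positive.
logits = torch.tensor([[[2.0, -2.0], [0.0, 0.4], [-3.0, 3.0]]])
print(binding_site_predictions(logits, threshold=0.5))  # argmax-equivalent cut-off
print(binding_site_predictions(logits, threshold=0.8))  # stricter cut-off
```

With the loaded model, `binding_site_predictions(logits, threshold=...)` can stand in for `torch.argmax(logits, dim=2)`. The borderline token (probability about 0.60) is dropped at the stricter cut-off, while the confident prediction is kept.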