---
language:
- en
license: mit
library_name: peft
tags:
- ESM-2
- protein language model
- binding sites
- biology
datasets:
- AmelieSchreiber/binding_sites_random_split_by_family_550K
metrics:
- accuracy
- f1
- roc_auc
- precision
- recall
- matthews_correlation
pipeline_tag: token-classification
base_model: facebook/esm2_t12_35M_UR50D
---
# ESM-2 for Binding Site Prediction
**This model is overfit (see below).** This model is a finetuned version of the 35M parameter `esm2_t12_35M_UR50D` ([see here](https://huggingface.co/facebook/esm2_t12_35M_UR50D)
and [here](https://huggingface.co/docs/transformers/model_doc/esm) for more details). The model was finetuned with LoRA for
the binary token classification task of predicting binding sites (and active sites) of protein sequences based on sequence alone.
Although the model is overfit (see the analysis below), it still achieved better performance on the test set in terms of loss, accuracy,
precision, recall, F1 score, ROC AUC, and Matthews correlation coefficient (MCC) than the models trained on the smaller
dataset [found here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family) of ~209K protein sequences. Note that
this model has high recall, meaning it is likely to detect binding sites, but low precision, meaning it will likely also return
false positives.
## Training procedure
This model was finetuned on ~549K protein sequences from the UniProt database. The dataset can be found
[here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). The model obtains
the following test metrics:
```python
Train: {'accuracy': 0.9905461579981686,
        'precision': 0.7695765003685506,
        'recall': 0.9841352974610041,
        'f1': 0.8637307441810476,
        'auc': 0.9874413786006525,
        'mcc': 0.8658850560635515}

Test:  {'accuracy': 0.9394282959813123,
        'precision': 0.3662722265170941,
        'recall': 0.8330231316088238,
        'f1': 0.5088208423175958,
        'auc': 0.8883078682492643,
        'mcc': 0.5283098562376193}
```
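For reference, per-residue metrics like these can be computed with scikit-learn. The sketch below uses placeholder arrays (`true_labels`, `pred_scores`) purely for illustration; it is not the actual evaluation code used for this model:

```python
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, matthews_corrcoef)

# Placeholder per-residue labels and model scores (special tokens already removed)
true_labels = np.array([0, 0, 1, 1, 0, 1, 0, 0])
pred_scores = np.array([0.1, 0.4, 0.8, 0.9, 0.3, 0.6, 0.7, 0.2])
pred_labels = (pred_scores > 0.5).astype(int)  # threshold probabilities at 0.5

metrics = {
    'accuracy': accuracy_score(true_labels, pred_labels),
    'precision': precision_score(true_labels, pred_labels),
    'recall': recall_score(true_labels, pred_labels),
    'f1': f1_score(true_labels, pred_labels),
    'auc': roc_auc_score(true_labels, pred_scores),
    'mcc': matthews_corrcoef(true_labels, pred_labels),
}
print(metrics)
```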
To analyze the train and test metrics, we will consider each metric individually and then offer a comprehensive view of the
model’s performance. Let's start:
### **1. Accuracy**
- **Train**: 99.05%
- **Test**: 93.94%
The accuracy is quite high in both the training and test datasets, indicating that the model is correctly identifying the positive
and negative classes most of the time.
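One caveat worth keeping in mind: binding-site residues typically make up a small fraction of a sequence, so accuracy is dominated by the majority (non-binding) class. A toy illustration, using a made-up 5% positive rate rather than the dataset's actual ratio:

```python
import numpy as np
from sklearn.metrics import accuracy_score

labels = np.array([0] * 95 + [1] * 5)  # 5% binding-site residues (illustrative)
preds = np.zeros_like(labels)          # trivial all-negative predictor
print(accuracy_score(labels, preds))   # 0.95 without finding a single binding site
```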
### **2. Precision**
- **Train**: 76.96%
- **Test**: 36.63%
The precision, which measures the proportion of true positive predictions among all positive predictions, drops significantly in
the test set. This suggests that the model might be identifying too many false positives when generalized to unseen data.
### **3. Recall**
- **Train**: 98.41%
- **Test**: 83.30%
The recall, which indicates the proportion of actual positives correctly identified, remains quite high in the test set, although
lower than in the training set. This suggests the model is quite sensitive and is able to identify most of the positive cases.
### **4. F1-Score**
- **Train**: 86.37%
- **Test**: 50.88%
The F1-score is the harmonic mean of precision and recall. The significant drop in the F1-score from training to testing indicates
that the balance between precision and recall has worsened in the test set, which is primarily due to the lower precision.
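As a quick sanity check, the reported test F1 follows directly from the test precision and recall:

```python
precision, recall = 0.3662722265170941, 0.8330231316088238
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
print(f1)  # ≈ 0.5088, matching the reported test F1
```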
### **5. AUC (Area Under the ROC Curve)**
- **Train**: 98.74%
- **Test**: 88.83%
The AUC is high in both training and testing, but it decreases on the test set. A high AUC indicates that the model has a good measure
of separability and is able to distinguish between the positive and negative classes well.
### **6. MCC (Matthews Correlation Coefficient)**
- **Train**: 86.59%
- **Test**: 52.83%
MCC is a balanced metric that considers true and false positives and negatives. The decline in MCC from training to testing indicates
a decrease in the quality of binary classifications.
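For reference, MCC is computed from all four confusion-matrix counts. The sketch below uses made-up counts, chosen only to mimic a high-recall, low-precision regime:

```python
import math

# Hypothetical confusion-matrix counts (illustrative only)
tp, fp = 800, 1400   # many false positives: low precision
tn, fn = 90000, 160  # few missed sites: high recall
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)
print(round(mcc, 3))  # ≈ 0.544 with these toy counts
```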
### **Overall Analysis**
- **Overfitting**: The significant drop in metrics such as precision, F1-score, and MCC from training to test set suggests that the model might be overfitting to the training data, i.e., it may not generalize well to unseen data.
- **High Recall, Low Precision**: The model has a high recall but low precision on the test set, indicating that it is identifying too many cases as positive, including those that are actually negative (false positives). This could be a reflection of a model that is biased towards predicting the positive class.
- **Improvement Suggestions**:
  - **Data Augmentation**: Consider data augmentation strategies to make the model more robust.
  - **Class Weights**: If there is a class imbalance in the dataset, adjusting the class weights during training might help (see the sketch after this list).
  - **Hyperparameter Tuning**: Experiment with different hyperparameters, such as the learning rate and batch size, to see if performance on the test set improves.
  - **Feature Engineering**: Consider revisiting the features used to train the model. Sometimes introducing new features or removing irrelevant ones can improve performance.
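As an illustration of the class-weights suggestion, a weighted token-classification loss in PyTorch could look like the following; the weight values are placeholders, not tuned for this dataset:

```python
import torch
import torch.nn as nn

# Hypothetical weights: up-weight the rare "binding site" class
class_weights = torch.tensor([1.0, 10.0])
loss_fn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)

# logits: (batch, seq_len, num_labels); in practice, labels would use
# -100 to mask special and padded tokens so they are ignored by the loss
logits = torch.randn(2, 8, 2)
labels = torch.randint(0, 2, (2, 8))
loss = loss_fn(logits.view(-1, 2), labels.view(-1))
```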
In conclusion, while the model performs excellently on the training set, its performance drops in the test set, suggesting that there
is room for improvement to make the model more generalizable to unseen data. It would be beneficial to look into strategies to reduce
overfitting and improve precision without significantly sacrificing recall.
The dataset size increase from ~209K protein sequences to ~549K clearly improved performance in terms of the test metrics.
We used Hugging Face's parameter-efficient finetuning (PEFT) library to finetune with Low-Rank Adaptation (LoRA). We decided
to use a LoRA rank of 2, as this was shown to slightly improve the test metrics compared to ranks 8 and 16 on the
same model trained on the smaller dataset.
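For reference, a rank-2 LoRA setup of this kind can be built with PEFT; the sketch below is illustrative, and the target modules, alpha, and dropout values are assumptions rather than the exact training configuration used for this model:

```python
from transformers import AutoModelForTokenClassification
from peft import LoraConfig, get_peft_model, TaskType

base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=2,                       # rank 2, as used for this model
    lora_alpha=16,             # assumed value for illustration
    lora_dropout=0.1,          # assumed value for illustration
    target_modules=["query", "key", "value"],  # assumed attention projections
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```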
### Framework versions
- PEFT 0.5.0
## Using the model
To use the model on one of your protein sequences, try running the following:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
``` |