---
widget:
- text: "MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF"
example_title: "Protein Sequence 1"
- text: "MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY"
example_title: "Protein Sequence 2"
- text: "MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY"
example_title: "Protein Sequence 3"
license: mit
language:
- en
metrics:
- f1
- accuracy
- precision
- recall
- matthews_correlation
- roc_auc
library_name: peft
tags:
- ESM-2
- protein language model
- biology
- binding sites
---
## Training:
For a report on the training [please see here](https://api.wandb.ai/links/amelie-schreiber-math/84t5gsfm) and
[here](https://wandb.ai/amelie-schreiber-math/huggingface/reports/ESM-2-Binding-Sites-Predictor-Scaling-Up--Vmlldzo1Mzc3MTAz?accessToken=cbl9v3bvuq65j5t4qo9l0bhccm3hrse8nt01t3dka6h6zb0azzakahnxdxfrb28m).
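If you want to set up a comparable LoRA fine-tune yourself, a minimal sketch is shown below. The rank, alpha, dropout, and target modules here are illustrative assumptions, not the settings used for this checkpoint; the actual training configuration is documented in the W&B reports linked above.
```python
from transformers import AutoModelForTokenClassification
from peft import LoraConfig, get_peft_model

# Load the ESM-2 base model with a two-class token classification head
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",
    num_labels=2,
)

# Hypothetical LoRA configuration targeting the attention projections of ESM-2;
# the values below are placeholders, not the settings used for this checkpoint.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    target_modules=["query", "key", "value"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```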
## Metrics:
```python
Train:
({'accuracy': 0.9406146072672105,
  'precision': 0.2947122459102886,
  'recall': 0.952624323712029,
  'f1': 0.4501592605994876,
  'auc': 0.9464622170085311,
  'mcc': 0.5118390407598565},
Test:
 {'accuracy': 0.9266827008067329,
  'precision': 0.22378953253253775,
  'recall': 0.7790246675002842,
  'f1': 0.3476966444342296,
  'auc': 0.8547531675185658,
  'mcc': 0.3930283737012391})
```
## Using the Model
### Using the Model on Your Protein Sequences
To use the model on one of your protein sequences, try running the following:
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
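If you would rather collect the predicted binding-site positions instead of printing every token, a small post-processing step such as the following (continuing from the cell above, reusing `tokens`, `predictions`, and `id2label`) can be used:
```python
# Continuing from the cell above: collect the 1-based residue positions
# that the model predicts as binding sites, skipping special tokens.
binding_sites = []
position = 0
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token in ['<pad>', '<cls>', '<eos>']:
        continue
    position += 1
    if id2label[prediction] == "Binding site":
        binding_sites.append((position, token))

print(binding_sites)  # list of (residue position, residue) pairs
```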
### Getting the Train/Test Metrics:
Head over to [here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family)
to download the dataset first. Once you have the pickle files downloaded locally, run the following:
```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]
# Set the maximum sequence length
max_sequence_length = 1000
# Load the data from pickle files
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
test_labels = pickle.load(f)
# Tokenize the sequences
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
# Truncate the labels to match the tokenized sequence lengths
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)
# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```
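Before evaluating, a quick sanity check of the datasets can be useful (this step is optional and not part of the original workflow):
```python
# Optional sanity check: the label lists may be shorter than the padded
# input_ids; DataCollatorForTokenClassification pads labels with -100 at
# prediction time, so the lengths need not match here.
print(train_dataset)
print(test_dataset)

sample = train_dataset[0]
print(len(sample["input_ids"]), len(sample["labels"]))
```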
Then run the following to get the train/test metrics:
```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification, Trainer
from accelerate import Accelerator
# Instantiate the accelerator
accelerator = Accelerator()
# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1" # "path/to/your/lora/model" Replace with the correct path to your LoRA model
# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model) # Prepare the model using the accelerator
# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
# Create a data collator
data_collator = DataCollatorForTokenClassification(tokenizer)
# Define a function to compute the metrics
def compute_metrics(dataset):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary
# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)
train_metrics, test_metrics
```
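If evaluation runs out of GPU memory on the longer sequences, one option (not part of the original script) is to pass `TrainingArguments` with a smaller evaluation batch size when constructing the `Trainer` inside `compute_metrics`:
```python
from transformers import Trainer, TrainingArguments

# Hypothetical variant of the Trainer constructed inside compute_metrics,
# with a smaller explicit evaluation batch size to limit GPU memory use.
eval_args = TrainingArguments(
    output_dir="eval_output",        # required by TrainingArguments; predict() writes no checkpoints
    per_device_eval_batch_size=4,    # lower further if you still run out of memory
    report_to="none",                # avoid extra logging during evaluation
)

trainer = Trainer(model=model, args=eval_args, data_collator=data_collator)
```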