AmelieSchreiber committed
Commit 6187032
1 Parent(s): a728b2c

Update README.md

Files changed (1)
  1. README.md +70 -1
README.md CHANGED
---
library_name: peft
license: mit
datasets:
- AmelieSchreiber/binding_sites_random_split_by_family_550K
metrics:
- accuracy
- f1
- roc_auc
- precision
- recall
- matthews_correlation
---

## Training procedure

This model was fine-tuned on ~549K protein sequences from the UniProt database. The dataset can be found
[here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). The model obtains
the following test metrics:

| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 | AUC | MCC |
|-------|---------------|-----------------|----------|-----------|--------|----------|----------|----------|
| 1 | 0.037400 | 0.301413 | 0.939431 | 0.366282 | 0.833003 | 0.508826 | 0.888300 | 0.528311 |
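
The dataset linked above can be pulled straight from the Hugging Face Hub. Below is a minimal sketch using the `datasets` library; the split and column names are not documented here, so inspect the loaded object before using it.

```python
from datasets import load_dataset

# Load the binding-site dataset referenced above from the Hugging Face Hub.
dataset = load_dataset("AmelieSchreiber/binding_sites_random_split_by_family_550K")

# Inspect the available splits and column names before training or evaluation.
print(dataset)
```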
 
Increasing the dataset size from ~209K to ~549K protein sequences clearly improved the test metrics.
We used Hugging Face's parameter-efficient fine-tuning (PEFT) library to fine-tune with Low-Rank Adaptation (LoRA).
We chose a LoRA rank of 2, as this slightly improved the test metrics compared to ranks 8 and 16 for the
same model trained on the smaller dataset.
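
For reference, here is a minimal sketch of what a rank-2 LoRA setup with PEFT might look like for this base model. The target modules, `lora_alpha`, and `lora_dropout` values are illustrative assumptions; the card does not list the exact training hyperparameters.

```python
from transformers import AutoModelForTokenClassification
from peft import LoraConfig, TaskType, get_peft_model

# Base ESM2 model with a 2-class token classification head
# (binding site / no binding site).
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

# Rank-2 LoRA, as described above. The target modules, alpha, and dropout
# here are illustrative assumptions, not the exact training configuration.
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=2,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```

A rank of 2 keeps the number of trainable parameters very small, which is consistent with the observation above that the lower rank slightly outperformed ranks 8 and 16.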

### Framework versions

- PEFT 0.5.0

## Using the model

To use the model on one of your protein sequences, try running the following:

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
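
If you prefer a single standalone checkpoint rather than loading the base model plus adapter, the LoRA weights can be folded into the base model. This is a sketch assuming PEFT's `merge_and_unload()` (available for LoRA adapters); the output directory name is just an example.

```python
from transformers import AutoModelForTokenClassification
from peft import PeftModel

# Load the base model and attach the LoRA adapter, as in the example above.
base_model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t12_35M_UR50D")
peft_model = PeftModel.from_pretrained(base_model, "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1")

# Fold the LoRA weights into the base weights, yielding a plain ESM2 token classifier.
merged_model = peft_model.merge_and_unload()

# Save the merged model to a local directory (name is illustrative).
merged_model.save_pretrained("esm2_t12_35M_binding_sites_merged")
```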