---
license: mit
---

## Training

For a report on the training, [please see here](https://wandb.ai/amelie-schreiber-math/huggingface/reports/ESM-2-Binding-Sites-Predictor-Scaling-Up--Vmlldzo1Mzc3MTAz?accessToken=cbl9v3bvuq65j5t4qo9l0bhccm3hrse8nt01t3dka6h6zb0azzakahnxdxfrb28m).

## Metrics

```text
Train: {'accuracy': 0.9406146072672105, 'precision': 0.2947122459102886, 'recall': 0.952624323712029, 'f1': 0.4501592605994876, 'auc': 0.9464622170085311, 'mcc': 0.5118390407598565}
Test:  {'accuracy': 0.9266827008067329, 'precision': 0.22378953253253775, 'recall': 0.7790246675002842, 'f1': 0.3476966444342296, 'auc': 0.8547531675185658, 'mcc': 0.3930283737012391}
```

Note that `auc` is computed from the hard (argmax) predictions rather than from predicted probabilities (see the evaluation code below).

## Using the Model

First, download the dataset from [this Hugging Face dataset repository](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family). Once you have the pickle files downloaded locally, run the following:

```python
from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Function to truncate labels
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

# Set the maximum sequence length
max_sequence_length = 1000

# Load the data from pickle files.
# NOTE: unpickling can execute arbitrary code -- only load pickle files
# from a source you trust.
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

# Tokenize the sequences
train_tokenized = tokenizer(
    train_sequences,
    padding=True,
    truncation=True,
    max_length=max_sequence_length,
    return_tensors="pt",
    is_split_into_words=False,
)
test_tokenized = tokenizer(
    test_sequences,
    padding=True,
    truncation=True,
    max_length=max_sequence_length,
    return_tensors="pt",
    is_split_into_words=False,
)

# Truncate the labels to match the tokenized sequence lengths.
# NOTE(review): the ESM tokenizer adds special tokens (<cls>/<eos>) that count
# toward max_length, and the labels below are neither shifted nor set to -100
# at those positions, so label i lines up with token position i. This
# presumably mirrors the preprocessing used during training; keep it as-is to
# reproduce the metrics above -- TODO confirm against the training script.
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Create train and test datasets
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
```

Then run the following to get the train/test metrics (this reuses `tokenizer`, `train_dataset` and `test_dataset` from the snippet above):

```python
import numpy as np
from sklearn.metrics import (
    matthews_corrcoef,
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator

# Instantiate the accelerator
accelerator = Accelerator()

# Define paths to the LoRA and base models
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"  # or "path/to/your/lora/model" -- replace with the correct path to your LoRA model

# Load the base model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Load the LoRA model
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)  # Prepare the model using the accelerator

# Define label mappings (for reference: class 1 is the positive "binding site" class)
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

# Create a data collator (pads labels with -100 up to the input length)
data_collator = DataCollatorForTokenClassification(tokenizer)

# Build the Trainer once and reuse it for every dataset we evaluate
trainer = Trainer(model=model, data_collator=data_collator)


def compute_metrics(dataset):
    """Run the model on `dataset` and return token-level binary metrics.

    Returns a dict with keys "accuracy", "precision", "recall", "f1",
    "auc" and "mcc". Positions whose label is -100 are excluded.
    """
    # Get the predictions (logits) and the collated labels
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Drop positions labelled -100 (collator padding). Special tokens are NOT
    # removed here, because the labels never mark them with -100.
    # Boolean-mask indexing already returns a flat 1-D array.
    mask = labels != -100
    true_labels = labels[mask]
    flat_predictions = np.argmax(predictions, axis=2)[mask]

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    # NOTE: AUC from hard 0/1 predictions equals balanced accuracy; this matches
    # the reported numbers. For a threshold-free ROC AUC, pass the softmax
    # probability of class 1 instead of `flat_predictions`.
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}


# Get the metrics for the training and test datasets
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)
print("Train:", train_metrics)
print("Test:", test_metrics)
```