#ref: https://huggingface.co/blog/AmelieSchreiber/esmbind import gradio as gr import os # os.environ["CUDA_VISIBLE_DEVICES"] = "0" #import wandb import numpy as np import torch import torch.nn as nn import pickle import xml.etree.ElementTree as ET from datetime import datetime from sklearn.model_selection import train_test_split from sklearn.utils.class_weight import compute_class_weight from sklearn.metrics import ( accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef ) from transformers import ( AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer ) from peft import PeftModel from datasets import Dataset from accelerate import Accelerator # Imports specific to the custom peft lora model from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType from plot_pdb import plot_struc def suggest(option): if option == "Plastic degradation protein": suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ" elif option == "Default protein": #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE" suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" elif option == "Antifreeze protein": suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH" elif option == "AI Generated protein": suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS" elif option == "7-bladed propeller fold": suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK" else: suggestion = "" return suggestion # Helper Functions and Data Preparation def truncate_labels(labels, max_length): """Truncate labels to the specified max_length.""" return [label[:max_length] for label in labels] def compute_metrics(p): """Compute metrics for evaluation.""" predictions, labels = p predictions = np.argmax(predictions, axis=2) # Remove padding (-100 labels) predictions = predictions[labels != -100].flatten() labels = labels[labels != -100].flatten() # Compute accuracy accuracy = accuracy_score(labels, predictions) # Compute precision, recall, F1 score, and AUC precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary') auc = roc_auc_score(labels, predictions) # Compute MCC mcc = matthews_corrcoef(labels, predictions) return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc} def compute_loss(model, inputs): """Custom compute_loss function.""" logits = model(**inputs).logits labels = inputs["labels"] loss_fct = nn.CrossEntropyLoss(weight=class_weights) active_loss = inputs["attention_mask"].view(-1) == 1 active_logits = logits.view(-1, model.config.num_labels) active_labels = torch.where( active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) ) loss = loss_fct(active_logits, active_labels) return loss # Define Custom Trainer Class # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer. class WeightedTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): outputs = model(**inputs) loss = compute_loss(model, inputs) return (loss, outputs) if return_outputs else loss # Predict binding site with finetuned PEFT model def predict_bind(base_model_path,PEFT_model_path,input_seq): # Load the model base_model = AutoModelForTokenClassification.from_pretrained(base_model_path) loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path) # Ensure the model is in evaluation mode loaded_model.eval() # Tokenization tokenizer = AutoTokenizer.from_pretrained(base_model_path) # Tokenize the sequence inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length') # Run the model with torch.no_grad(): logits = loaded_model(**inputs).logits # Get predictions tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens predictions = torch.argmax(logits, dim=2) binding_site=[] pos = 0 # Print the predicted labels for each token for token, prediction in zip(tokens, predictions[0].numpy()): if token not in ['', '', '']: pos += 1 print((pos, token, id2label[prediction])) if prediction == 1: print((pos, token, id2label[prediction])) binding_site.append([pos, token, id2label[prediction]]) return binding_site # fine-tuning function def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset): # Set the LoRA config config = { "lora_alpha": 1, #try 0.5, 1, 2, ..., 16 "lora_dropout": 0.2, "lr": 5.701568055793089e-04, "lr_scheduler_type": "cosine", "max_grad_norm": 0.5, "num_train_epochs": 1, #3, jw 20240628 "per_device_train_batch_size": 12, "r": 2, "weight_decay": 0.2, # Add other hyperparameters as needed } base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id) # Tokenization tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D") train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False) test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False) train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels) test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels) # Convert the model into a PeftModel peft_config = LoraConfig( task_type=TaskType.TOKEN_CLS, inference_mode=False, r=config["r"], lora_alpha=config["lora_alpha"], target_modules=["query", "key", "value"], # also try "dense_h_to_4h" and "dense_4h_to_h" lora_dropout=config["lora_dropout"], bias="none" # or "all" or "lora_only" ) base_model = get_peft_model(base_model, peft_config) # Use the accelerator base_model = accelerator.prepare(base_model) train_dataset = accelerator.prepare(train_dataset) test_dataset = accelerator.prepare(test_dataset) model_name_base = base_model_path.split("/")[1] timestamp = datetime.now().strftime('%Y-%m-%d_%H') save_path = f"{model_name_base}-lora-binding-sites_{timestamp}" # Training setup training_args = TrainingArguments( output_dir=save_path, #f"{model_name_base}-lora-binding-sites_{timestamp}", learning_rate=config["lr"], lr_scheduler_type=config["lr_scheduler_type"], gradient_accumulation_steps=1, max_grad_norm=config["max_grad_norm"], per_device_train_batch_size=config["per_device_train_batch_size"], per_device_eval_batch_size=config["per_device_train_batch_size"], num_train_epochs=config["num_train_epochs"], weight_decay=config["weight_decay"], evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", greater_is_better=True, push_to_hub=True, #jw 20240701 False, logging_dir=None, logging_first_step=False, logging_steps=200, save_total_limit=7, no_cuda=False, seed=8893, fp16=True, #report_to='wandb' report_to=None, hub_token = HF_TOKEN, #jw 20240701 ) # Initialize Trainer trainer = WeightedTrainer( model=base_model, args=training_args, train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=tokenizer, data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer), compute_metrics=compute_metrics, ) # Train and Save Model trainer.train() return save_path # Constants & Globals HF_TOKEN = os.environ.get("HF_token") print("HF_TOKEN:",HF_TOKEN) MODEL_OPTIONS = [ "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t33_650M_UR50D", ] # models users can choose from PEFT_MODEL_OPTIONS = [ "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54", "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3", ] # finetuned models # Load the data from pickle files (replace with your local paths) with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f: train_sequences = pickle.load(f) with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f: test_sequences = pickle.load(f) with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f: train_labels = pickle.load(f) with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f: test_labels = pickle.load(f) max_sequence_length = 1000 # Directly truncate the entire list of labels train_labels = truncate_labels(train_labels, max_sequence_length) test_labels = truncate_labels(test_labels, max_sequence_length) # Compute Class Weights classes = [0, 1] flat_train_labels = [label for sublist in train_labels for label in sublist] class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels) accelerator = Accelerator() class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device) # Define labels and model id2label = {0: "No binding site", 1: "Binding site"} label2id = {v: k for k, v in id2label.items()} ''' # debug result dubug_result = saved_path #predictions #class_weights ''' demo = gr.Blocks(title="DEMO FOR ESM2Bind") with demo: gr.Markdown("# DEMO FOR ESM2Bind") #gr.Textbox(dubug_result) with gr.Column(): gr.Markdown("## Select a base model and a corresponding PEFT finetune model") with gr.Row(): with gr.Column(scale=5, variant="compact"): base_model_name = gr.Dropdown( choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Base Model Name", interactive = True, ) PEFT_model_name = gr.Dropdown( choices=PEFT_MODEL_OPTIONS, value=PEFT_MODEL_OPTIONS[0], label="PEFT Model Name", interactive = True, ) with gr.Column(scale=5, variant="compact"): name = gr.Dropdown( label="Choose a Sample Protein", value="Default protein", choices=["Default protein", "Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"] ) gr.Markdown( "## Predict binding site and Plot structure for selected protein sequence:" ) with gr.Row(): with gr.Column(variant="compact", scale = 8): input_seq = gr.Textbox( lines=1, max_lines=12, label="Protein sequency to be predicted:", value="MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT", placeholder="Paste your protein sequence here...", interactive = True, ) text_pos = gr.Textbox( lines=1, max_lines=12, label="Sequency Position:", placeholder= "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789", interactive=False, ) with gr.Column(variant="compact", scale = 2): predict_btn = gr.Button( value="Predict binding site", interactive=True, variant="primary", ) plot_struc_btn = gr.Button(value = "Plot ESMFold Predicted Structure ", variant="primary") with gr.Row(): with gr.Column(variant="compact", scale = 5): output_text = gr.Textbox( lines=1, max_lines=12, label="Output", placeholder="Output", ) with gr.Column(variant="compact", scale = 5): finetune_button = gr.Button( value="Finetune Pre-trained Model", interactive=True, variant="primary", ) with gr.Row(): output_viewer = gr.HTML() output_file = gr.File( label="Download as Text File", file_count="single", type="filepath", interactive=False, ) # select protein sample name.change(fn=suggest, inputs=name, outputs=input_seq) # "Predict binding site" actions predict_btn.click( fn = predict_bind, inputs=[base_model_name,PEFT_model_name,input_seq], outputs = [output_text], ) # "Finetune Pre-trained Model" actions finetune_button.click( fn = train_function_no_sweeps, inputs=[base_model_name], outputs = [output_text], ) # plot protein structure plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer]) demo.launch()