Overview
OneProt is a multimodal model that integrates protein sequence, protein structure (both as an augmented sequence and as a graph), protein binding sites, and protein text annotations. Contrastive learning is used to align each of the other modalities to the central one, the protein sequence. During pre-training, the InfoNCE loss is computed between pairs of (protein sequence, other modality).
Model architecture
Protein sequence encoder: esm2_t33_650M_UR50D
Protein structure encoder (structure-augmented sequence): esm2_t12_35M_UR50D
Protein structure encoder GNN: ProNet
Pocket (binding sites encoder) GNN: ProNet
Text encoder: BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext (formerly named BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext)
Below is example code showing how to obtain the embeddings (this requires cloning our repo first). Note that the example data for the transformer models are read from .txt
files and could in principle be passed directly as strings, whilst the data for the GNN models are contained in the example .h5
files and need to be converted to graphs first.
import torch
import hydra
from omegaconf import OmegaConf
from huggingface_hub import HfApi, hf_hub_download
import sys
import os
import h5py
from torch_geometric.data import Batch
from transformers import AutoTokenizer
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) # assuming that you are running this script from the oneprot repo, can be any other path
from src.models.oneprot_module import OneProtLitModule
from src.data.utils.struct_graph_utils import protein_to_graph
# Fetch the OneProt configuration file from the Hugging Face Hub and parse it
# with OmegaConf.
config_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="config.yaml",
)
with open(config_path, 'r') as config_file:
    cfg = OmegaConf.load(config_file)

# Build one encoder per modality by letting Hydra instantiate the matching
# node under cfg.model.components; the dict keeps the modality order below.
components = {
    name: hydra.utils.instantiate(cfg.model.components[name])
    for name in ("sequence", "struct_token", "struct_graph", "pocket", "text")
}
# Download the pretrained OneProt weights from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="pytorch_model.bin",
    repo_type="model",
)

# Wrap the instantiated encoders in the Lightning module. No optimizer is
# passed because the model is only used for inference here; the remaining
# arguments are taken from the downloaded config.yaml.
model = OneProtLitModule(
    components=components,
    optimizer=None,
    loss_fn=cfg.model.loss_fn,
    local_loss=cfg.model.local_loss,
    gather_with_grad=cfg.model.gather_with_grad,
    use_l1_regularization=cfg.model.use_l1_regularization,
    train_on_all_modalities_after_step=cfg.model.train_on_all_modalities_after_step,
    use_seqsim=cfg.model.use_seqsim,
)

# map_location="cpu" lets the checkpoint load on machines without the device
# it was saved from; move the model to a GPU afterwards if one is available.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)

# Switch to inference mode so any dropout / batch-norm layers behave
# deterministically when the embeddings below are computed.
model.eval()
# Define the tokenisers for the transformer-based modalities. The graph
# modalities (struct_graph, pocket) are fed graphs built by protein_to_graph
# and therefore need no tokeniser.
# The struct_token modality reuses the 650M ESM-2 tokeniser; ESM-2 models of
# different sizes presumably share one vocabulary -- confirm if this changes.
tokenizers = {
    'sequence': "facebook/esm2_t33_650M_UR50D",
    'struct_token': "facebook/esm2_t33_650M_UR50D",
    'text': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
}
loaded_tokenizers = {
    modality: AutoTokenizer.from_pretrained(tokenizer_name)
    for modality, tokenizer_name in tokenizers.items()
}
# Get example embeddings for each modality
##########################sequence##############################
modality = "sequence"
# The example inputs are stored in the model repository itself.
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/sequence_example.txt",
    repo_type="model",
)
with open(file_path, 'r', encoding="utf-8") as file:
    input_sequence = file.read().strip()
input_tensor = loaded_tokenizers[modality](input_sequence, return_tensors="pt")["input_ids"]
# Only the embedding is needed, so skip building the autograd graph.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
###########################text#################################
modality = "text"
# The example inputs are stored in the model repository itself.
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/text_example.txt",
    repo_type="model",
)
with open(file_path, 'r', encoding="utf-8") as file:
    input_text = file.read().strip()
input_tensor = loaded_tokenizers[modality](input_text, return_tensors="pt")["input_ids"]
# Only the embedding is needed, so skip building the autograd graph.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
#####################tokenized structure########################
modality = "struct_token"
# The example inputs are stored in the model repository itself.
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/struct_token_example.txt",
    repo_type="model",
)
with open(file_path, 'r', encoding="utf-8") as file:
    input_struct_token = file.read().strip()
# Remove every '#' character before tokenisation (presumably '#' marks
# positions the encoder should not see -- confirm against the data pipeline).
input_struct_token = input_struct_token.replace("#", "")
input_tensor = loaded_tokenizers[modality](input_struct_token, return_tensors="pt")["input_ids"]
# Only the embedding is needed, so skip building the autograd graph.
with torch.no_grad():
    output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")
#####################graph structure############################
modality = "struct_graph"
# The example inputs are stored in the model repository itself.
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/seqstruc_example.h5",
    repo_type="model",
)
# protein_to_graph is given the HDF5 path directly, so no h5py handle needs to
# be opened here. 'E6Y2X0' is the example protein entry; 'non_pdb' and 'A' are
# presumably the structure source and chain ID -- see protein_to_graph.
input_struct_graph = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=False)]
# The GNN encoder consumes a torch_geometric Batch, even for a single graph.
input_struct_graph = Batch.from_data_list(input_struct_graph)
# Only the embedding is needed, so skip building the autograd graph.
with torch.no_grad():
    output = model.network[modality](input_struct_graph)
print(f"Output for modality '{modality}': {output}")
##########################pocket################################
modality = "pocket"
# The example inputs are stored in the model repository itself.
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/pocket_example.h5",
    repo_type="model",
)
# protein_to_graph is given the HDF5 path directly, so no h5py handle needs to
# be opened here; pockets=True requests the binding-site graph rather than
# the whole-structure graph.
input_pocket = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=True)]
# The GNN encoder consumes a torch_geometric Batch, even for a single graph.
input_pocket = Batch.from_data_list(input_pocket)
# Only the embedding is needed, so skip building the autograd graph.
with torch.no_grad():
    output = model.network[modality](input_pocket)
print(f"Output for modality '{modality}': {output}")
Citation
@misc{flöge2024oneprotmultimodalproteinfoundation,
title={OneProt: Towards Multi-Modal Protein Foundation Models},
author={Klemens Flöge and Srisruthi Udayakumar and Johanna Sommer and Marie Piraud and Stefan Kesselheim and Vincent Fortuin and Stephan Günneman and Karel J van der Weg and Holger Gohlke and Alina Bazarova and Erinc Merdivan},
year={2024},
eprint={2411.04863},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2411.04863},
}
- Downloads last month
- 2