Geneformer
Geneformer is a foundation transformer model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
Abstract
Mapping gene networks requires large amounts of transcriptomic data to learn the connections between genes, which impedes discoveries in settings with limited data, including rare diseases and diseases affecting clinically inaccessible tissues. Recently, transfer learning has revolutionized fields such as natural language understanding and computer vision by leveraging deep learning models pretrained on large-scale general datasets that can then be fine-tuned towards a vast array of downstream tasks with limited task-specific data. Here, we developed a context-aware, attention-based deep learning model, Geneformer, pretrained on a large-scale corpus of about 30 million single-cell transcriptomes to enable context-specific predictions in settings with limited data in network biology. During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the attention weights of the model in a completely self-supervised manner. Fine-tuning towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modelling with limited patient data, Geneformer identified candidate therapeutic targets for cardiomyopathy. Overall, Geneformer represents a pretrained deep learning model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
Code
from tdc.model_server.tokenizers.geneformer import GeneformerTokenizer
from tdc import tdc_hf_interface
import torch
# Load an AnnData object (assumed here to already be in memory as `adata`), then tokenize its cells
tokenizer = GeneformerTokenizer()
x = tokenizer.tokenize_cell_vectors(adata,
                                    ensembl_id="feature_id",
                                    ncounts="n_measured_vars")
cells, _ = x
input_tensor = torch.tensor(cells)  # note that you may need to pad or perform other custom data processing (see the sketch below)
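# If the tokenized cells have unequal lengths, the torch.tensor(cells) call above will fail.
# A minimal padding sketch (our assumption, not part of the TDC API) that right-pads each
# cell with 0, treating 0 as the padding token id, before stacking:
#   input_tensor = torch.nn.utils.rnn.pad_sequence(
#       [torch.as_tensor(c, dtype=torch.long) for c in cells],
#       batch_first=True, padding_value=0)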
# retrieve model
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
# run inference: mask out padding positions so attention ignores them
attention_mask = (input_tensor != 0).long()  # here we assume 0 was used as the special padding token
outputs = model(input_tensor,
                attention_mask=attention_mask,
                output_hidden_states=True)
# outputs.hidden_states holds the initial embedding output plus one entry per layer;
# index -2 is Geneformer's second-to-last layer, which is the most generalized
layer_to_quant = len(outputs.hidden_states) - 2
embs_i = outputs.hidden_states[layer_to_quant]
# there are "cls", "cell", and "gene" embeddings; here we keep the per-token "gene" embeddings,
# which are cell-type specific. For "cell" embeddings, average the unmasked gene embeddings
# of each cell (a minimal pooling sketch follows below)
embs = embs_i
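# Optional: a minimal sketch (our own pooling, not part of the TDC interface) of deriving
# one embedding per cell by mean-pooling the gene embeddings over unmasked positions only
mask = attention_mask.unsqueeze(-1).to(embs.dtype)                    # (batch, seq_len, 1)
cell_embs = (embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)   # (batch, hidden_size)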
TDC Citation
@inproceedings{
velez-arce2024signals,
title={Signals in the Cells: Multimodal and Contextualized Machine Learning Foundations for Therapeutics},
author={Alejandro Velez-Arce and Kexin Huang and Michelle M Li and Xiang Lin and Wenhao Gao and Bradley Pentelute and Tianfan Fu and Manolis Kellis and Marinka Zitnik},
booktitle={NeurIPS 2024 Workshop on AI for New Drug Modalities},
year={2024},
url={https://openreview.net/forum?id=kL8dlYp6IM}
}
Additional Citations
- C V Theodoris#, L Xiao, A Chopra, M D Chaffin, Z R Al Sayed, M C Hill, H Mantineo, E Brydon, Z Zeng, X S Liu, P T Ellinor#. Transfer learning enables predictions in network biology. Nature, 31 May 2023. (#co-corresponding authors)
- H Chen*, M S Venkatesh*, J Gomez Ortega, S V Mahesh, T Nandi, R Madduri, K Pelka†, C V Theodoris†#. Quantized multi-task learning for context-specific representations of gene network dynamics. bioRxiv, 19 Aug 2024. (*co-first authors, †co-senior authors, #corresponding author)
Model HF Homepage