InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on ESM2-650M activations from 1M UniProt protein sequences. The SAE implementation largely follows Gao et al., using a Top-K activation function, but with far fewer latent dimensions.
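
For readers unfamiliar with the Top-K activation, here is a minimal PyTorch sketch of a Top-K SAE forward pass in the spirit of Gao et al. It is an illustrative assumption, not the interprot `SparseAutoencoder` class: the class name `TopKSAE`, the pre-encoder bias `b_pre`, the ReLU on the kept values, and `k=64` are hypothetical choices. Each token keeps only its k largest latent pre-activations; all others are set to zero, so at most k latents are active per token.

```python
import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """Hypothetical minimal Top-K sparse autoencoder (not the interprot class)."""

    def __init__(self, d_model: int, d_latent: int, k: int):
        super().__init__()
        self.k = k
        # Pre-encoder bias: subtracted before encoding, added back after decoding
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_latent)
        self.decoder = nn.Linear(d_latent, d_model, bias=False)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = self.encoder(x - self.b_pre)
        # Keep the k largest pre-activations per token and zero out the rest
        topk = torch.topk(pre_acts, self.k, dim=-1)
        acts = torch.zeros_like(pre_acts)
        acts.scatter_(-1, topk.indices, torch.relu(topk.values))
        return acts

    def forward(self, x: torch.Tensor):
        acts = self.encode(x)
        recon = self.decoder(acts) + self.b_pre
        return recon, acts


torch.manual_seed(0)
# 1280 matches the ESM2-650M hidden size; k=64 is an arbitrary illustrative value
sae = TopKSAE(d_model=1280, d_latent=4096, k=64)
x = torch.randn(1, 10, 1280)  # random stand-in for (batch, tokens, hidden) activations
with torch.no_grad():
    recon, acts = sae(x)
print(recon.shape, acts.shape)
print((acts != 0).sum(dim=-1).max())  # never more than k active latents per token
```

Because sparsity is enforced directly by k rather than by an L1 penalty, the number of active latents per token is bounded by construction.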

Check out https://interprot.com for an interactive visualizer of the 4096-dimensional SAE trained on ESM2 layer 24.

Installation

pip install git+https://github.com/etowahadams/interprot.git

Usage

Load the SAE

from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder

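# Input dimension 1280 (ESM2-650M hidden size) and 4096 SAE latents
sae_model = SparseAutoencoder(1280, 4096)
# Local path to the downloaded checkpoint (ESM2 layer 24, 4096 latents)
checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
sae_model.load_state_dict(load_file(checkpoint_path))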

Load ESM, run ESM inference, then run SAE inference on its activations

import torch
from transformers import AutoTokenizer, EsmModel

# Load ESM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Run ESM inference on a sequence and take the layer 24 activations
# (hidden_states[0] is the embedding output, so index 24 is the output of layer 24)
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
esm_layer = 24

inputs = tokenizer([seq], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
esm_layer_acts = outputs.hidden_states[esm_layer] # (1, sequence length + 2, 1280); +2 for the <cls> and <eos> tokens

# Run SAE inference with ESM activations as input
with torch.no_grad():
    sae_acts = sae_model.get_acts(esm_layer_acts)
sae_acts # (1, sequence length + 2, 4096)
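
A common next step is to align the SAE activations with residues and look up which latents fire where. The sketch below assumes `sae_acts` has the shape shown above; a random non-negative tensor stands in for the real output so the snippet runs on its own, and the residue index 10 and latent 42 are hypothetical examples. It only drops the first and last positions, which is correct for a single unpadded sequence; with padded batches, use the tokenizer's `attention_mask` to exclude padding.

```python
import torch

seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"

# Random stand-in for sae_acts: (1, sequence length + 2, 4096)
torch.manual_seed(0)
sae_acts = torch.relu(torch.randn(1, len(seq) + 2, 4096))

# Drop the <cls> (first) and <eos> (last) positions so rows align with residues
residue_acts = sae_acts[0, 1:-1]  # (sequence length, 4096)

# Five strongest latents at one residue (0-based index into seq)
values, latent_ids = torch.topk(residue_acts[10], k=5)

# Maximum activation of every latent over the sequence, and the residue where it occurs
max_per_latent, argmax_pos = residue_acts.max(dim=0)  # each of shape (4096,)

latent = 42  # hypothetical latent of interest
print(latent_ids.tolist(), values.tolist())
print(max_per_latent[latent].item(), seq[argmax_pos[latent].item()])
```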