---
license: apache-2.0
---

# InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on [ESM2-650](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation mostly follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function, though with far fewer latent dimensions.

Check out [https://interprot.com](https://interprot.com) for an interactive visualizer of the 4096-dimensional SAE on ESM layer 24.

## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```

## Usage

Load the SAE

```python
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder

# 1280 is the ESM2-650 hidden size; 4096 is the number of SAE latents
sae_model = SparseAutoencoder(1280, 4096)
checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
sae_model.load_state_dict(load_file(checkpoint_path))
```

Load ESM and run ESM inference -> SAE inference

```python
import torch
from transformers import AutoTokenizer, EsmModel

# Load ESM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Run ESM inference with some sequence and take layer 24 activations
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
esm_layer = 24
inputs = tokenizer([seq], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
esm_layer_acts = outputs.hidden_states[esm_layer]  # (1, sequence length + 2, 1280)

# Run SAE inference with ESM activations as input
sae_acts = sae_model.get_acts(esm_layer_acts)
sae_acts  # (1, sequence length + 2, 4096)
```
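For readers unfamiliar with the Top-K activation mentioned in the introduction, here is a minimal, dependency-free sketch of the idea: for each token, only the `k` largest latent pre-activations are kept and the rest are set to zero. This is an illustration only, not the InterProt implementation; the function name `top_k_activation` and the example values are hypothetical, and the actual value of `k` used for these checkpoints is not stated here.

```python
from typing import List


def top_k_activation(pre_acts: List[float], k: int) -> List[float]:
    """Keep the k largest pre-activations and zero out the rest.

    Hypothetical helper for illustration; a real SAE would apply this
    to a tensor of encoder outputs, one row per token.
    """
    if k <= 0:
        return [0.0] * len(pre_acts)
    # Indices of the k largest values (ties go to the lower index,
    # because Python's sort is stable even with reverse=True).
    order = sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre_acts)]


# Five made-up latent pre-activations for a single token, with k = 2
sparse = top_k_activation([0.2, -1.0, 3.5, 0.9, 2.1], k=2)
print(sparse)  # [0.0, 0.0, 3.5, 0.0, 2.1]
```

Because at most `k` latents are nonzero per token, the activations are sparse by construction.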