""" ESM++ model implementation. ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility The ESM Python package is not required Modified from https://github.com/evolutionaryscale/esm License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement """ import math import os import torch import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass from functools import cache, partial from pathlib import Path from typing import Optional, Tuple, Union, List, Callable, Dict from einops import rearrange, repeat from huggingface_hub import snapshot_download from tokenizers import Tokenizer from tokenizers.models import BPE from tokenizers.processors import TemplateProcessing from torch.utils.data import Dataset as TorchDataset from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import PreTrainedModel, PreTrainedTokenizerFast, PreTrainedTokenizerBase, PretrainedConfig from transformers.modeling_outputs import ModelOutput class ESMplusplusConfig(PretrainedConfig): """Configuration class for ESM++ model. Args: vocab_size: Size of the vocabulary hidden_size: Dimension of hidden layers num_attention_heads: Number of attention heads num_hidden_layers: Number of transformer layers num_labels: Number of output labels for classification problem_type: Type of problem - regression, single/multi label classification """ model_type = "ESMplusplus" def __init__( self, vocab_size: int = 64, hidden_size: int = 960, num_attention_heads: int = 15, num_hidden_layers: int = 30, num_labels: int = 2, problem_type: str | None = None, dropout: float = 0.0, initializer_range: float = 0.02, **kwargs, ): super().__init__(**kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_hidden_layers = num_hidden_layers self.num_labels = num_labels self.problem_type = problem_type self.dropout = dropout self.initializer_range = initializer_range ### Rotary Embeddings def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: """Rotates half the hidden dims of the input.""" if not interleaved: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] return rearrange( torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 ) def apply_rotary_emb_torch( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, _inplace: bool = False, ) -> torch.Tensor: """Apply rotary embeddings to input based on cos and sin.""" ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] seqlen = x.size(1) cos = cos[:seqlen] sin = sin[:seqlen] cos = repeat(cos, "s d -> s 1 (2 d)") sin = repeat(sin, "s d -> s 1 (2 d)") return torch.cat( [ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:], ], dim=-1, ) class RotaryEmbedding(torch.nn.Module): """Rotary position embeddings. 
Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding" Args: dim: Dimension of the embedding base: Base for computing angular frequencies interleaved: Whether to use interleaved rotations scale_base: Base for scaling scaling_factor: Factor for scaling positions pos_idx_in_fp32: Whether to compute position indices in fp32 device: Computation device """ def __init__( self, dim: int, base: float = 10000.0, interleaved: bool = False, scale_base: Optional[float] = None, scaling_factor: float = 1.0, pos_idx_in_fp32: bool = True, device: Optional[torch.device] = None, ): super().__init__() self.dim = dim self.base = float(base) self.pos_idx_in_fp32 = pos_idx_in_fp32 self.interleaved = interleaved self.scale_base = scale_base self.scaling_factor = scaling_factor self.device = device self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None self._cos_k_cached = None self._sin_k_cached = None self.reset_parameters() def reset_parameters(self): """Reset the parameters of the embedding.""" inv_freq = self._compute_inv_freq(self.device) self.register_buffer("inv_freq", inv_freq, persistent=False) arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) scale = ( (arange + 0.4 * self.dim) / (1.4 * self.dim) if self.scale_base is not None else None ) self.register_buffer("scale", scale) def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor: """Compute inverse frequency bands.""" return 1 / ( self.base ** ( torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim ) ) def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): """Update the cached cosine and sine values.""" if ( seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype or (self.training and self._cos_cached.is_inference()) ): self._seq_len_cached = seqlen if self.pos_idx_in_fp32: t = torch.arange(seqlen, device=device, dtype=torch.float32) t /= self.scaling_factor if self.inv_freq.dtype != torch.float32: inv_freq = self.inv_freq.to(torch.float32) else: inv_freq = self.inv_freq else: t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t /= self.scaling_factor inv_freq = self.inv_freq freqs = torch.outer(t, inv_freq) if self.scale is None: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: power = ( torch.arange( seqlen, dtype=self.scale.dtype, device=self.scale.device ) - seqlen // 2 ) / self.scale_base scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) self._cos_cached = (torch.cos(freqs) * scale).to(dtype) self._sin_cached = (torch.sin(freqs) * scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to queries and keys. 
        Args:
            q: Query tensor of shape (batch, seqlen, nheads, headdim)
            k: Key tensor of shape (batch, seqlen, nheads, headdim)

        Returns:
            Tuple of rotated query and key tensors
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            assert False, "scale_base rotary scaling is not supported in this implementation"


### Feedforward Network Components
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Compute corrected dimension for SwiGLU, rounded up to a multiple of 256."""
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    """SwiGLU activation function."""
    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
    """Create SwiGLU feedforward network with layer normalization."""
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
    )


### Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention with rotary embeddings.

    Args:
        d_model: Model dimension
        n_heads: Number of attention heads
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key."""
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor after self attention, and optionally attention weights
        """
        attn_weights = None
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        if output_attentions:  # Manual attention computation
            L, S = query_BLD.size(-2), key_BLD.size(-2)
            scale = 1 / math.sqrt(query_BHLD.size(-1))  # scale by head dim to match F.scaled_dot_product_attention
            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
            if attention_mask is not None:
                if attention_mask.dtype == torch.bool:
                    # Build an additive bias from the boolean mask instead of filling the mask in place
                    # (filling a bool tensor with -inf fails and would also clobber the caller's mask).
                    attn_bias = torch.zeros_like(attention_mask, dtype=query_BLD.dtype).masked_fill(
                        attention_mask.logical_not(), float("-inf")
                    )
                else:
                    attn_bias = attn_bias + attention_mask
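            # The explicit path below mirrors F.scaled_dot_product_attention while keeping the
            # per-head attention probabilities so they can be returned to the caller:
            # attn = softmax(Q @ K^T * scale + bias), context = attn @ V.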
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale attn_weights += attn_bias attn_weights = F.softmax(attn_weights, dim=-1) context_BHLD = torch.matmul(attn_weights, value_BHLD) else: context_BHLD = F.scaled_dot_product_attention( query_BHLD, key_BHLD, value_BHLD, attention_mask ) context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)") output = self.out_proj(context_BLD) return output, attn_weights ### Regression Head def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module: """Create a regression head with optional hidden dimension. Args: d_model: Input dimension output_dim: Output dimension hidden_dim: Optional hidden dimension (defaults to d_model) """ hidden_dim = hidden_dim if hidden_dim is not None else d_model return nn.Sequential( nn.Linear(d_model, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, output_dim), ) ### Transformer Block class UnifiedTransformerBlock(nn.Module): """Transformer block with attention and feedforward layers. Args: d_model: Model dimension n_heads: Number of attention heads residue_scaling_factor: Factor for scaling residual connections expansion_ratio: Expansion ratio for feedforward network """ def __init__( self, d_model: int, n_heads: int, residue_scaling_factor: float = 1, expansion_ratio: float = 8 / 3, dropout: float = 0.0, ): super().__init__() self.attn = MultiHeadAttention(d_model, n_heads) self.ffn = swiglu_ln_ffn(d_model, expansion_ratio) self.scaling_factor = residue_scaling_factor self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Args: x: Input tensor attention_mask: Optional attention mask output_attentions: Whether to return attention weights Returns: Output tensor after transformer block, and optionally attention weights """ attn_output, attn_weights = self.attn(x, attention_mask, output_attentions) x = x + self.dropout(attn_output) / self.scaling_factor x = x + self.dropout(self.ffn(x)) / self.scaling_factor return x, attn_weights ### Model Outputs @dataclass class TransformerOutput(ModelOutput): """Output type for transformer encoder.""" last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor]] = None attentions: Optional[Tuple[torch.Tensor]] = None @dataclass class ESMplusplusOutput(ModelOutput): """Output type for ESM++ models.""" loss: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor]] = None attentions: Optional[Tuple[torch.Tensor]] = None ### Transformer Stack class TransformerStack(nn.Module): """Stack of transformer blocks. 
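    A rough shape check with randomly initialized weights (dimensions here follow the 300M config
    and are only illustrative):

        >>> stack = TransformerStack(d_model=960, n_heads=15, n_layers=30)
        >>> out = stack(torch.randn(1, 12, 960))
        >>> out.last_hidden_state.shape
        torch.Size([1, 12, 960])
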
Args: d_model: Model dimension n_heads: Number of attention heads n_layers: Number of transformer layers dropout: Dropout rate """ def __init__( self, d_model: int, n_heads: int, n_layers: int, dropout: float = 0.0, ): super().__init__() self.blocks = nn.ModuleList( [ UnifiedTransformerBlock( d_model, n_heads, residue_scaling_factor=math.sqrt(n_layers / 36), dropout=dropout, ) for i in range(n_layers) ] ) self.norm = nn.LayerNorm(d_model, bias=False) self.gradient_checkpointing = False def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False, output_attentions: bool = False, ) -> TransformerOutput: """ Args: x: Input tensor attention_mask: Optional attention mask output_hidden_states: Whether to return all hidden states output_attentions: Whether to return attention weights Returns: TransformerOutput containing last hidden state and optionally all hidden states and attention weights """ batch_size, seq_len, _ = x.shape hidden_states = () if output_hidden_states else None attentions = () if output_attentions else None if attention_mask is not None: attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool() for block in self.blocks: if self.gradient_checkpointing and self.training: x, attn_weights = self._gradient_checkpointing_func( block.__call__, x, attention_mask, output_attentions, ) else: x, attn_weights = block(x, attention_mask, output_attentions) if attentions is not None: attentions += (attn_weights,) if output_hidden_states: assert hidden_states is not None hidden_states += (x,) return TransformerOutput( last_hidden_state=self.norm(x), hidden_states=hidden_states, attentions=attentions ) ### Support for embedding datasets with low code class Pooler: def __init__(self, pooling_types: List[str]): self.pooling_types = pooling_types self.pooling_options = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'min': self.min_pooling, 'norm': self.norm_pooling, 'prod': self.prod_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, } def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.max(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).max(dim=1).values def min_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.min(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).min(dim=1).values def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def prod_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) length = emb.shape[1] if attention_mask is None: return emb.prod(dim=1) / length else: attention_mask = attention_mask.unsqueeze(-1) return ((emb * attention_mask).prod(dim=1) / attention_mask.sum(dim=1)) / length def 
median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.median(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).median(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.std(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).std(dim=1) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.var(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).var(dim=1) def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) return emb[:, 0, :] def __call__(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # [mean, max] final_emb = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb, attention_mask)) # (b, d) return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) class ProteinDataset(TorchDataset): """Simple dataset for protein sequences.""" def __init__(self, sequences: list[str]): self.sequences = sequences def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> str: return self.sequences[idx] def build_collator(tokenizer) -> Callable[[list[str]], tuple[torch.Tensor, torch.Tensor]]: def _collate_fn(sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]: """Collate function for batching sequences.""" return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8) return _collate_fn class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> set[str]: """Read sequences from SQLite database.""" import sqlite3 sequences = [] with sqlite3.connect(db_path) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") while True: row = c.fetchone() if row is None: break sequences.append(row[0]) return set(sequences) def embed_dataset( self, sequences: List[str], tokenizer: PreTrainedTokenizerBase, batch_size: int = 2, max_len: int = 512, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], num_workers: int = 0, sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', ) -> Optional[dict[str, torch.Tensor]]: """Embed a dataset of protein sequences. 
Args: sequences: List of protein sequences batch_size: Batch size for processing max_len: Maximum sequence length full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False) pooling_type: Type of pooling ('mean' or 'cls') num_workers: Number of workers for data loading, 0 for the main process sql: Whether to store embeddings in SQLite database - will be stored in float32 sql_db_path: Path to SQLite database Returns: Dictionary mapping sequences to embeddings, or None if sql=True Note: - If sql=True, embeddings can only be stored in float32 - sql is ideal if you need to stream a very large dataset for training in real-time - save=True is ideal if you can store the entire embedding dictionary in RAM - sql will be used if it is True and save is True or False - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences - Sequences will be truncated to max_len and sorted by length in descending order for faster processing Example: >>> embedder = EmbeddingMixin() >>> embedding_dict = embedder.embed_dataset( sequences=[ 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences ], batch_size=2, # adjust for your GPU memory max_len=512, # adjust for your needs full_embeddings=False, # if True, no pooling is performed embed_dtype=torch.float32, # cast to what dtype you want pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets sql=False, # if True, embeddings will be stored in SQLite database sql_db_path='embeddings.db', save=True, # if True, embeddings will be saved as a .pth file save_path='embeddings.pth', ) >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql """ sequences = list(set([seq[:max_len] for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) collate_fn = build_collator(tokenizer) device = self.device pooler = Pooler(pooling_types) if not full_embeddings else None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings return residue_embeddings else: return pooler(residue_embeddings, attention_mask) if sql: import sqlite3 conn = sqlite3.connect(sql_db_path) c = conn.cursor() c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)') already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: dataset = ProteinDataset(to_embed) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) with torch.no_grad(): for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device) residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32 embeddings = get_embeddings(residue_embeddings, attention_mask).cpu() for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()] 
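                            # Embeddings are stored as raw float32 bytes (native byte order); on the reading
                            # side they can be recovered with, e.g., numpy.frombuffer(blob, dtype=np.float32)
                            # and reshaped to (L, d) when full_embeddings=True.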
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes())) if (i + 1) % 100 == 0: conn.commit() conn.commit() conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True) to_embed = [seq for seq in sequences if seq not in embeddings_dict] print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: dataset = ProteinDataset(to_embed) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) with torch.no_grad(): for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device) residue_embeddings = self._embed(input_ids, attention_mask) embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu() for seq, emb in zip(seqs, embeddings): embeddings_dict[seq] = emb if save: torch.save(embeddings_dict, save_path) return embeddings_dict class PreTrainedESMplusplusModel(PreTrainedModel): """ init weights for ESM++ models """ config_class = ESMplusplusConfig base_model_prefix = "esm++" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): if module.bias is not None: module.bias.data.zero_() module.weight.data.fill_(1.0) @classmethod def from_pretrained_esm(cls, model_name: str): """Load a pretrained ESM++ model.""" if '300' in model_name: return ESMplusplus_300M() elif '600' in model_name: return ESMplusplus_600M() else: raise ValueError(f"Invalid model name: {model_name}") ### ESM++ Models class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin): """ ESM++ model. 
transformer model with no heads """ config_class = ESMplusplusConfig def __init__(self, config: ESMplusplusConfig, **kwargs): super(PreTrainedESMplusplusModel, self).__init__(config, **kwargs) self.config = config self.vocab_size = config.vocab_size self.embed = nn.Embedding(self.vocab_size, config.hidden_size) self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout) self.tokenizer = EsmSequenceTokenizer() self.init_weights() def get_input_embeddings(self): return self.embed def set_input_embeddings(self, value): self.embed = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.embed(input_ids) return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, # to play nice with HF adjacent packages ) -> TransformerOutput: """Forward pass for masked language modeling. Args: input_ids: Input token IDs attention_mask: Attention mask inputs_embeds: Optional precomputed embeddings output_hidden_states: Whether to return all hidden states output_attentions: Whether to return attention weights Returns: TransformerOutput containing last hidden state and optionally all hidden states and attention weights """ if inputs_embeds is None: x = self.embed(input_ids) else: x = inputs_embeds return self.transformer(x, attention_mask, output_hidden_states, output_attentions) class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin): """ ESM++ model for masked language modeling. Implements the base ESM++ architecture with a masked language modeling head. """ config_class = ESMplusplusConfig def __init__(self, config: ESMplusplusConfig, **kwargs): super(PreTrainedESMplusplusModel, self).__init__(config, **kwargs) self.config = config self.vocab_size = config.vocab_size self.embed = nn.Embedding(self.vocab_size, config.hidden_size) self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout) self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size) self.ce_loss = nn.CrossEntropyLoss() self.tokenizer = EsmSequenceTokenizer() self.init_weights() def get_input_embeddings(self): return self.embed def set_input_embeddings(self, value): self.embed = value def get_output_embeddings(self): return self.sequence_head[-1] def set_output_embeddings(self, new_embeddings): self.sequence_head[-1] = new_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.embed(input_ids) return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, # to play nice with HF adjacent packages ) -> ESMplusplusOutput: """Forward pass for masked language modeling. 
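        A rough doctest-style sketch; weights are randomly initialized here, so load a pretrained
        checkpoint (e.g., via ESMplusplus_300M()) for meaningful logits:

            >>> model = ESMplusplusForMaskedLM(ESMplusplusConfig())
            >>> batch = model.tokenizer(["MALWMRLLPLLALLALWGPDPAAA"], return_tensors="pt")
            >>> out = model(**batch, labels=batch["input_ids"])
            >>> out.logits.shape[-1] == model.config.vocab_size
            True
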
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for masked tokens
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, hidden states and attention weights
        """
        if inputs_embeds is None:
            x = self.embed(input_ids)
        else:
            x = inputs_embeds
        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
        x = output.last_hidden_state
        logits = self.sequence_head(x)
        loss = None
        if labels is not None:
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )


class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for sequence classification.
    Extends the base ESM++ model with a classification head.
    """
    def __init__(self, config: ESMplusplusConfig, **kwargs):
        super().__init__(config, **kwargs)  # build the shared embedding/transformer backbone from ESMplusplusForMaskedLM
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with sequence classification tasks (*4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attention-mask-aware mean pooling over the sequence dimension, used by forward."""
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> ESMplusplusOutput:
        """Forward pass for sequence classification.
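        A rough doctest-style sketch with randomly initialized weights (sequence and label count
        are illustrative):

            >>> clf = ESMplusplusForSequenceClassification(ESMplusplusConfig(num_labels=3))
            >>> batch = clf.tokenizer(["MKTAYIAKQR"], return_tensors="pt")
            >>> clf(**batch).logits.shape
            torch.Size([1, 3])
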
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        cls_features = x[:, 0, :]
        mean_features = self.mean_pooling(x, attention_mask)
        # we include mean pooling features to help with early convergence; the cost of this is basically zero
        features = torch.cat([cls_features, mean_features], dim=-1)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.flatten(), labels.flatten())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
    """
    ESM++ model for token classification.
    Extends the base ESM++ model with a token classification head.
    """
    def __init__(self, config: ESMplusplusConfig):
        super().__init__(config)  # build the shared embedding/transformer backbone from ESMplusplusForMaskedLM
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        # Large intermediate projections help with token classification tasks (*4)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embed(input_ids)
        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> ESMplusplusOutput:
        """Forward pass for token classification.
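        A rough doctest-style sketch with randomly initialized weights (per-residue logits include
        the <cls>/<eos> positions):

            >>> tok_clf = ESMplusplusForTokenClassification(ESMplusplusConfig(num_labels=5))
            >>> batch = tok_clf.tokenizer(["MKTAYIAKQR"], return_tensors="pt")
            >>> tok_clf(**batch).logits.shape
            torch.Size([1, 12, 5])
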
        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            inputs_embeds: Optional precomputed embeddings
            labels: Optional labels for token classification
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            ESMplusplusOutput containing loss, logits, and hidden states
        """
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )
        x = output.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


### Loading from EvolutionaryScale
@staticmethod
@cache
def data_root(model: str):
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    # Try to download from Hugging Face if it doesn't exist locally
    if model.startswith("esmc-300"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
    elif model.startswith("esmc-600"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
    else:
        raise ValueError(f"{model=} is an invalid model name.")
    return path


def ESMplusplus_300M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=960,
            num_attention_heads=15,
            num_hidden_layers=30,
        )
        model = ESMplusplusForMaskedLM(config)
    state_dict = torch.load(
        data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model


def ESMplusplus_600M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=1152,
            num_attention_heads=18,
            num_hidden_layers=36,
        )
        model = ESMplusplusForMaskedLM(config)
    state_dict = torch.load(
        data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model


### Tokenization
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R",
    "T", "I", "D", "P", "K", "Q", "N",
    "F", "Y", "M", "H", "W", "C", "X",
    "B", "U", "Z", "O", ".", "-", "|",
    "<mask>",
]


class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
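        # The template below wraps every encoded sequence as: <cls> <sequence tokens> <eos>.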
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
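

# A minimal, hypothetical end-to-end sketch (not part of the original module): tokenize two
# sequences, run the masked-LM backbone, and pool the residue embeddings. Weights are randomly
# initialized here; use ESMplusplus_300M() / ESMplusplus_600M() to load the pretrained ESMC
# weights instead.
if __name__ == "__main__":
    tokenizer = EsmSequenceTokenizer()
    model = ESMplusplusForMaskedLM(ESMplusplusConfig())
    batch = tokenizer(
        ["MALWMRLLPLLALLALWGPDPAAA", "MKTAYIAKQR"],
        return_tensors="pt",
        padding="longest",
    )
    with torch.no_grad():
        out = model(**batch, output_hidden_states=True)
    pooled = Pooler(["mean", "cls"])(out.last_hidden_state, batch["attention_mask"])
    print(out.logits.shape)   # (batch, padded_len, vocab_size)
    print(pooled.shape)       # (batch, 2 * hidden_size)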