lhallee committed (verified)
Commit cd0641b · 1 Parent(s): cfe6f59

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +219 -101
modeling_esm_plusplus.py CHANGED
@@ -16,15 +16,16 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from functools import cache, partial
 from pathlib import Path
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List, Callable, Dict
 from einops import rearrange, repeat
 from huggingface_hub import snapshot_download
 from tokenizers import Tokenizer
 from tokenizers.models import BPE
 from tokenizers.processors import TemplateProcessing
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
+from transformers import PreTrainedModel, PreTrainedTokenizerFast, PreTrainedTokenizerBase, PretrainedConfig
 from transformers.modeling_outputs import ModelOutput


@@ -501,8 +502,90 @@ class TransformerStack(nn.Module):
         )


-### Dataset for Embedding
-class ProteinDataset(Dataset):
+### Support for embedding datasets with low code
+class Pooler:
+    def __init__(self, pooling_types: List[str]):
+        self.pooling_types = pooling_types
+        self.pooling_options = {
+            'mean': self.mean_pooling,
+            'max': self.max_pooling,
+            'min': self.min_pooling,
+            'norm': self.norm_pooling,
+            'prod': self.prod_pooling,
+            'median': self.median_pooling,
+            'std': self.std_pooling,
+            'var': self.var_pooling,
+            'cls': self.cls_pooling,
+        }
+
+    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.mean(dim=1)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
+
+    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.max(dim=1).values
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).max(dim=1).values
+
+    def min_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.min(dim=1).values
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).min(dim=1).values
+
+    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.norm(dim=1, p=2)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).norm(dim=1, p=2)
+
+    def prod_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        length = emb.shape[1]
+        if attention_mask is None:
+            return emb.prod(dim=1) / length
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return ((emb * attention_mask).prod(dim=1) / attention_mask.sum(dim=1)) / length
+
+    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.median(dim=1).values
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).median(dim=1).values
+
+    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.std(dim=1)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).std(dim=1)
+
+    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        if attention_mask is None:
+            return emb.var(dim=1)
+        else:
+            attention_mask = attention_mask.unsqueeze(-1)
+            return (emb * attention_mask).var(dim=1)
+
+    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
+        return emb[:, 0, :]
+
+    def __call__(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # [mean, max]
+        final_emb = []
+        for pooling_type in self.pooling_types:
+            final_emb.append(self.pooling_options[pooling_type](emb, attention_mask)) # (b, d)
+        return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
+
+
+class ProteinDataset(TorchDataset):
     """Simple dataset for protein sequences."""
     def __init__(self, sequences: list[str]):
         self.sequences = sequences
@@ -514,68 +597,22 @@ class ProteinDataset(Dataset):
         return self.sequences[idx]


-class PreTrainedESMplusplusModel(PreTrainedModel):
-    """
-    init weights for ESM++ models
-    """
-    config_class = ESMplusplusConfig
-    base_model_prefix = "esm++"
-    supports_gradient_checkpointing = True
+def build_collator(tokenizer) -> Callable[[list[str]], tuple[torch.Tensor, torch.Tensor]]:
+    def _collate_fn(sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
+        """Collate function for batching sequences."""
+        return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
+    return _collate_fn

-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.LayerNorm):
-            if module.bias is not None:
-                module.bias.data.zero_()
-            module.weight.data.fill_(1.0)

-    @classmethod
-    def from_pretrained_esm(cls, model_name: str):
-        """Load a pretrained ESM++ model."""
-        if '300' in model_name:
-            return ESMplusplus_300M()
-        elif '600' in model_name:
-            return ESMplusplus_600M()
-        else:
-            raise ValueError(f"Invalid model name: {model_name}")
+class EmbeddingMixin:
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        raise NotImplementedError

     @property
     def device(self) -> torch.device:
         """Get the device of the model."""
         return next(self.parameters()).device

-    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply mean pooling to sequence outputs."""
-        if attention_mask is None:
-            return x.mean(dim=1)
-        else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
-
-    def max_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply max pooling to sequence outputs."""
-        if attention_mask is None:
-            return x.max(dim=1).values
-        else:
-            attention_mask = attention_mask.unsqueeze(-1)
-            return (x * attention_mask).max(dim=1).values
-
-    def cls_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply cls pooling to sequence outputs."""
-        return x[:, 0, :]
-
-    def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
-        """Collate function for batching sequences."""
-        return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
-
     def _read_sequences_from_db(self, db_path: str) -> set[str]:
         """Read sequences from SQLite database."""
         import sqlite3
@@ -592,15 +629,18 @@ class PreTrainedESMplusplusModel(PreTrainedModel):

     def embed_dataset(
         self,
-        sequences: list[str],
+        sequences: List[str],
+        tokenizer: PreTrainedTokenizerBase,
         batch_size: int = 2,
         max_len: int = 512,
         full_embeddings: bool = False,
-        full_precision: bool = False,
-        pooling_type: str = 'mean',
+        embed_dtype: torch.dtype = torch.float32,
+        pooling_types: List[str] = ['mean'],
         num_workers: int = 0,
         sql: bool = False,
+        save: bool = True,
         sql_db_path: str = 'embeddings.db',
+        save_path: str = 'embeddings.pth',
     ) -> Optional[dict[str, torch.Tensor]]:
         """Embed a dataset of protein sequences.

@@ -609,7 +649,6 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             batch_size: Batch size for processing
             max_len: Maximum sequence length
             full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
-            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
             pooling_type: Type of pooling ('mean' or 'cls')
             num_workers: Number of workers for data loading, 0 for the main process
             sql: Whether to store embeddings in SQLite database - will be stored in float32
@@ -617,23 +656,46 @@ class PreTrainedESMplusplusModel(PreTrainedModel):

         Returns:
             Dictionary mapping sequences to embeddings, or None if sql=True
+
+        Note:
+            - If sql=True, embeddings can only be stored in float32
+            - sql is ideal if you need to stream a very large dataset for training in real-time
+            - save=True is ideal if you can store the entire embedding dictionary in RAM
+            - sql will be used if it is True and save is True or False
+            - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
+            - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
+
+        Example:
+            >>> embedder = EmbeddingMixin()
+            >>> embedding_dict = embedder.embed_dataset(
+                sequences=[
+                    'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
+                ],
+                batch_size=2, # adjust for your GPU memory
+                max_len=512, # adjust for your needs
+                full_embeddings=False, # if True, no pooling is performed
+                embed_dtype=torch.float32, # cast to what dtype you want
+                pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
+                num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
+                sql=False, # if True, embeddings will be stored in SQLite database
+                sql_db_path='embeddings.db',
+                save=True, # if True, embeddings will be saved as a .pth file
+                save_path='embeddings.pth',
+            )
+            >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
         """
         sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        collate_fn = build_collator(tokenizer)
         device = self.device
+        pooler = Pooler(pooling_types) if not full_embeddings else None

         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-            if full_embeddings:
+            if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings
                 return residue_embeddings
-            elif pooling_type == 'mean':
-                return self.mean_pooling(residue_embeddings, attention_mask)
-            elif pooling_type == 'max':
-                return self.max_pooling(residue_embeddings, attention_mask)
-            elif pooling_type == 'cls':
-                return self.cls_pooling(residue_embeddings, attention_mask)
             else:
-                raise ValueError(f"Invalid pooling type: {pooling_type}")
+                return pooler(residue_embeddings, attention_mask)

-        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -644,17 +706,14 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
-                to_embed = sorted(to_embed, key=len, reverse=True)
                 dataset = ProteinDataset(to_embed)
-                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                         input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-                        x = self.embed(input_ids)
-                        residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
-                        embeddings = get_embeddings(residue_embeddings, attention_mask)
-
+                        residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
+                        embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                         for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                             if full_embeddings:
                                 emb = emb[mask.bool()]
@@ -669,32 +728,75 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             return None

         embeddings_dict = {}
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
-        with torch.no_grad():
-            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
-                seqs = sequences[i * batch_size:(i + 1) * batch_size]
-                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
-                x = self.embed(input_ids)
-                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach()
-                if full_precision:
-                    residue_embeddings = residue_embeddings.float()
-                embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
-                for seq, emb in zip(seqs, embeddings):
-                    embeddings_dict[seq] = emb
-
+        if os.path.exists(save_path):
+            embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
+            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
+            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
+            print(f"Embedding {len(to_embed)} new sequences")
+        else:
+            to_embed = sequences
+            print(f"Embedding {len(to_embed)} new sequences")
+
+        if len(to_embed) > 0:
+            dataset = ProteinDataset(to_embed)
+            dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
+            with torch.no_grad():
+                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
+                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
+                    input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
+                    residue_embeddings = self._embed(input_ids, attention_mask)
+                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
+                    for seq, emb in zip(seqs, embeddings):
+                        embeddings_dict[seq] = emb
+
+        if save:
+            torch.save(embeddings_dict, save_path)
+
         return embeddings_dict


+class PreTrainedESMplusplusModel(PreTrainedModel):
+    """
+    init weights for ESM++ models
+    """
+    config_class = ESMplusplusConfig
+    base_model_prefix = "esm++"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    @classmethod
+    def from_pretrained_esm(cls, model_name: str):
+        """Load a pretrained ESM++ model."""
+        if '300' in model_name:
+            return ESMplusplus_300M()
+        elif '600' in model_name:
+            return ESMplusplus_600M()
+        else:
+            raise ValueError(f"Invalid model name: {model_name}")
+
+
 ### ESM++ Models
-class ESMplusplusModel(PreTrainedESMplusplusModel):
+class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
     """
     ESM++ model. transformer model with no heads
     """
     config_class = ESMplusplusConfig
     def __init__(self, config: ESMplusplusConfig, **kwargs):
-        super().__init__(config, **kwargs)
+        super(PreTrainedESMplusplusModel, self).__init__(config, **kwargs)
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -708,6 +810,10 @@ class ESMplusplusModel(PreTrainedESMplusplusModel):
     def set_input_embeddings(self, value):
         self.embed = value

+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = self.embed(input_ids)
+        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -736,14 +842,14 @@ class ESMplusplusModel(PreTrainedESMplusplusModel):
         return self.transformer(x, attention_mask, output_hidden_states, output_attentions)


-class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
+class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
     """
     ESM++ model for masked language modeling.
     Implements the base ESM++ architecture with a masked language modeling head.
     """
     config_class = ESMplusplusConfig
     def __init__(self, config: ESMplusplusConfig, **kwargs):
-        super().__init__(config, **kwargs)
+        super(PreTrainedESMplusplusModel, self).__init__(config, **kwargs)
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
@@ -765,6 +871,10 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
     def set_output_embeddings(self, new_embeddings):
         self.sequence_head[-1] = new_embeddings

+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = self.embed(input_ids)
+        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -807,13 +917,13 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
         )


-class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
+class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
     """
     ESM++ model for sequence classification.
     Extends the base ESM++ model with a classification head.
     """
     def __init__(self, config: ESMplusplusConfig, **kwargs):
-        super().__init__(config, **kwargs)
+        super(ESMplusplusForMaskedLM, self).__init__(config, **kwargs)
         self.config = config
         self.num_labels = config.num_labels
         self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
@@ -823,6 +933,10 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         self.bce = nn.BCEWithLogitsLoss()
         self.init_weights()

+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = self.embed(input_ids)
+        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -888,13 +1002,13 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         )


-class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
+class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
     """
     ESM++ model for token classification.
     Extends the base ESM++ model with a token classification head.
     """
     def __init__(self, config: ESMplusplusConfig):
-        super().__init__(config)
+        super(ESMplusplusForMaskedLM, self).__init__(config)
         self.config = config
         self.num_labels = config.num_labels
         self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
@@ -902,6 +1016,10 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         self.loss_fct = nn.CrossEntropyLoss()
         self.init_weights()

+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        x = self.embed(input_ids)
+        return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
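
Usage note (not part of the commit): a minimal sketch of how the embed_dataset API reworked in this diff might be called, assuming the file is loaded as remote code. The checkpoint id, the trust_remote_code loading path, and the model.tokenizer attribute are assumptions here; adjust them to your own setup.

import torch
from transformers import AutoModelForMaskedLM

# Assumed checkpoint id; any ESM++ checkpoint that ships this modeling file should behave similarly.
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True).eval()

sequences = ['MALWMRLLPLLALLALWGPDPAAA', 'MKTAYIAKQR']
embedding_dict = model.embed_dataset(
    sequences=sequences,
    tokenizer=model.tokenizer,      # tokenizer is now passed in explicitly (assumed attribute)
    batch_size=2,
    max_len=512,
    full_embeddings=False,          # pooled embeddings rather than residue-wise
    embed_dtype=torch.float16,      # dict storage can be cast; sql storage stays float32
    pooling_types=['mean', 'cls'],  # pooled features are concatenated along the last dim
    num_workers=0,
    sql=False,
    save=True,                      # also writes embeddings.pth for reuse on the next call
    save_path='embeddings.pth',
)
# embedding_dict maps each unique (truncated) sequence to a (2 * hidden_size,) tensor.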