import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from esm.pretrained import ESM3_sm_open_v0


class StabilityPredictionConfig(PretrainedConfig):
    def __init__(self, embed_dim=1536, *args, **kwargs):
        # Pass embed_dim through rather than hardcoding 1536, so a custom
        # value is not silently ignored.
        super().__init__(*args, embed_dim=embed_dim, **kwargs)


class SingleMutationPooler(nn.Module):
    """Pools the wild-type and mutant residue embeddings at the mutated
    position into a single difference-style vector."""

    def __init__(self, embed_dim=1536):
        super().__init__()
        # Learnable per-channel weights, initialized so the pooler starts
        # out as a plain (wt - mut) difference.
        self.wt_weight = nn.Parameter(torch.ones((1, embed_dim)), requires_grad=True)
        self.mut_weight = nn.Parameter(-1 * torch.ones((1, embed_dim)), requires_grad=True)
        self.norm = nn.LayerNorm(embed_dim, bias=False)

    def forward(self, wt_embedding, mut_embedding, positions):
        embed_dim = wt_embedding.shape[-1]
        # Shift by +1 to skip the BOS token ESM3 prepends, then expand the
        # indices to (batch, 1, embed_dim) for torch.gather along dim 1.
        positions = positions.view(-1, 1).unsqueeze(2).repeat(1, 1, embed_dim) + 1
        wt_residues = torch.gather(wt_embedding, 1, positions).squeeze(1)
        mut_residues = torch.gather(mut_embedding, 1, positions).squeeze(1)
        wt_residues = wt_residues * self.wt_weight
        mut_residues = mut_residues * self.mut_weight
        return self.norm(wt_residues + mut_residues)


class StabilityPrediction(PreTrainedModel):
    config_class = StabilityPredictionConfig

    def __init__(self, config=None):
        # Avoid a mutable default argument; build a fresh config if none is given.
        config = config if config is not None else StabilityPredictionConfig()
        super().__init__(config=config)
        self.backbone = ESM3_sm_open_v0(getattr(config, "device", "cpu"))
        self.pooler = SingleMutationPooler(config.embed_dim)
        self.regressor = nn.Linear(config.embed_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        # Start the regression head near zero so early predictions stay small.
        nn.init.uniform_(self.regressor.weight, -0.01, 0.01)
        nn.init.zeros_(self.regressor.bias)

    def compute_loss(self, logits, labels):
        if labels is None:
            return None
        # Flatten both sides so (batch, 1) logits match (batch,) labels.
        return F.mse_loss(logits.view(-1), labels.view(-1))

    def forward(self, wt_input_ids, mut_input_ids, positions, labels=None):
        wt_embeddings = self.backbone(sequence_tokens=wt_input_ids).embeddings
        mut_embeddings = self.backbone(sequence_tokens=mut_input_ids).embeddings
        aggregated_embeddings = self.pooler(wt_embeddings, mut_embeddings, positions)
        logits = self.regressor(aggregated_embeddings)
        loss = self.compute_loss(logits, labels)
        return {
            "loss": loss,
            "logits": logits,
        }
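

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original
    # module): random token ids stand in for real wild-type/mutant sequences,
    # which would normally come from the ESM3 sequence tokenizer. Running this
    # downloads the ESM3-sm-open weights.
    model = StabilityPrediction()
    batch_size, seq_len = 2, 12
    wt_ids = torch.randint(4, 24, (batch_size, seq_len))  # arbitrary residue tokens
    mut_ids = wt_ids.clone()
    positions = torch.tensor([3, 7])  # 0-based mutation sites
    # Perturb the token at positions + 1, matching the BOS offset the pooler uses.
    mut_ids[torch.arange(batch_size), positions + 1] += 1
    labels = torch.randn(batch_size)  # dummy stability targets

    out = model(wt_ids, mut_ids, positions, labels=labels)
    print(out["logits"].shape, out["loss"])  # expect torch.Size([2, 1]) and a scalar loss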