Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..modules import ( | |
TransformerLayer, | |
LearnedPositionalEmbedding, | |
SinusoidalPositionalEmbedding, | |
RobertaLMHead, | |
ESM1bLayerNorm, | |
ContactPredictionHead, | |
) | |
class ProteinBertModel(nn.Module):
    """Transformer protein language model with "ESM-1" and "ESM-1b" variants.

    The variant is chosen from ``args.arch`` at construction time:
    ``"roberta_large"`` builds the ESM-1b submodules (learned positional
    embedding, embedding layer norms, ``RobertaLMHead``); any other value
    builds the ESM-1 submodules (sinusoidal positional embedding and a plain
    linear output projection).
    """

    @classmethod
    def add_args(cls, parser):
        """Register model hyper-parameter options on ``parser``.

        Args:
            parser: Object exposing ``add_argument`` (for example an
                ``argparse.ArgumentParser``).

        NOTE(review): the model itself reads ``args.layers``, ``args.arch``,
        ``args.max_positions`` and ``args.final_bias``, none of which are
        registered here (``--num_layers`` is stored as ``num_layers``), so a
        namespace produced only by these options is not sufficient to build
        the model -- confirm how ``args`` is populated by callers.
        """
        parser.add_argument(
            "--num_layers", default=36, type=int, metavar="N", help="number of layers"
        )
        parser.add_argument(
            "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias", action="store_true", help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=5120,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=20,
            type=int,
            metavar="N",
            help="number of attention heads",
        )

    def __init__(self, args, alphabet):
        """Build the model described by ``args`` over ``alphabet``.

        Args:
            args: Namespace-like configuration. Attributes read by this class
                include ``arch``, ``embed_dim``, ``ffn_embed_dim``,
                ``attention_heads``, ``layers`` and, depending on the variant,
                ``max_positions`` (ESM-1b) or ``final_bias`` (ESM-1).
                ``emb_layer_norm_before`` and ``token_dropout`` are optional
                and default to ``False``.
            alphabet: Token vocabulary; must support ``len()`` and expose
                ``padding_idx``, ``mask_idx``, ``cls_idx``, ``eos_idx``,
                ``prepend_bos`` and ``append_eos``.
        """
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        # Holds the boolean flag for now; the ESM-1b initialiser replaces it
        # with the layer-norm module itself (or None).
        self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
        if self.args.arch == "roberta_large":
            self.model_version = "ESM-1b"
            self._init_submodules_esm1b()
        else:
            self.model_version = "ESM-1"
            self._init_submodules_esm1()

    def _init_submodules_common(self):
        """Create the submodules shared by both variants.

        Builds the token embedding, the stack of ``TransformerLayer`` modules
        and the contact prediction head. ``self.model_version`` must already
        be set because it selects the per-layer options.
        """
        is_esm1b = self.model_version == "ESM-1b"
        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    add_bias_kv=not is_esm1b,
                    use_esm1b_layer_norm=is_esm1b,
                )
                for _ in range(self.args.layers)
            ]
        )
        # The contact head consumes one attention map per (layer, head) pair.
        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )

    def _init_submodules_esm1b(self):
        """Create the ESM-1b specific submodules on top of the common ones."""
        self._init_submodules_common()
        self.embed_scale = 1
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions, self.args.embed_dim, self.padding_idx
        )
        # Turn the boolean flag set in __init__ into the module (or None).
        self.emb_layer_norm_before = (
            ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        # The LM head is handed the token embedding matrix as its weight.
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_submodules_esm1(self):
        """Create the ESM-1 specific submodules on top of the common ones."""
        self._init_submodules_common()
        self.embed_scale = math.sqrt(self.args.embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
        # Output projection (and optional bias) applied with F.linear in forward().
        self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
        self.embed_out_bias = None
        if self.args.final_bias:
            self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Run the model on a batch of token indices.

        Args:
            tokens: 2-D tensor of token indices, batch x sequence (B, T).
            repr_layers: Iterable of layer indices whose hidden states should
                be returned. ``0`` is the embedding output; ``i`` is the
                output of the i-th transformer layer.
            need_head_weights: If true, collect per-head attention weights and
                return them under ``"attentions"``.
            return_contacts: If true, also return the contact head output
                under ``"contacts"``; implies ``need_head_weights``.

        Returns:
            dict with ``"logits"`` and ``"representations"`` (layer index ->
            B x T x E tensor), plus ``"attentions"`` (B x L x H x T x T) when
            head weights are requested and ``"contacts"`` when requested.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if getattr(self.args, "token_dropout", False):
            is_mask = tokens == self.mask_idx
            x.masked_fill_(is_mask.unsqueeze(-1), 0.0)
            # x: B x T x C
            # Rescale so the embedding magnitude matches the mask ratio assumed
            # at training time (presumably 15% selected, 80% of those replaced
            # by the mask token -- TODO confirm).
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = is_mask.sum(-1).float() / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(tokens)

        if self.model_version == "ESM-1b":
            if self.emb_layer_norm_before is not None:
                x = self.emb_layer_norm_before(x)
            # Zero the embeddings at padded positions (padding_mask is always
            # a tensor at this point).
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        attn_weights = [] if need_head_weights else None

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        # Skip masking work inside the layers when nothing is padded.
        if not padding_mask.any():
            padding_mask = None

        for layer_number, layer in enumerate(self.layers, start=1):
            x, attn = layer(
                x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
            )
            if layer_number in repr_layers:
                hidden_representations[layer_number] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        last_layer = len(self.layers)
        if self.model_version == "ESM-1b":
            x = self.emb_layer_norm_after(x)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

            # last hidden representation should have layer norm applied
            if last_layer in repr_layers:
                hidden_representations[last_layer] = x
            x = self.lm_head(x)
        else:
            x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if self.model_version == "ESM-1":
                # ESM-1 models have an additional null-token for attention, which we remove
                attentions = attentions[..., :-1]
            if padding_mask is not None:
                # Zero every attention entry whose row or column is padding.
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            if return_contacts:
                result["contacts"] = self.contact_head(tokens, attentions)

        return result

    def predict_contacts(self, tokens):
        """Return only the ``"contacts"`` entry of a forward pass on ``tokens``."""
        return self(tokens, return_contacts=True)["contacts"]

    def num_layers(self):
        """Return the configured number of transformer layers (``args.layers``)."""
        return self.args.layers