EdwardoSunny's picture
finished
85ab89d
raw
history blame
7.34 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..modules import (
TransformerLayer,
LearnedPositionalEmbedding,
SinusoidalPositionalEmbedding,
RobertaLMHead,
ESM1bLayerNorm,
ContactPredictionHead,
)
class ProteinBertModel(nn.Module):
@classmethod
def add_args(cls, parser):
parser.add_argument(
"--num_layers", default=36, type=int, metavar="N", help="number of layers"
)
parser.add_argument(
"--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
)
parser.add_argument(
"--logit_bias", action="store_true", help="whether to apply bias to logits"
)
parser.add_argument(
"--ffn_embed_dim",
default=5120,
type=int,
metavar="N",
help="embedding dimension for FFN",
)
parser.add_argument(
"--attention_heads",
default=20,
type=int,
metavar="N",
help="number of attention heads",
)
def __init__(self, args, alphabet):
super().__init__()
self.args = args
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
if self.args.arch == "roberta_large":
self.model_version = "ESM-1b"
self._init_submodules_esm1b()
else:
self.model_version = "ESM-1"
self._init_submodules_esm1()
def _init_submodules_common(self):
self.embed_tokens = nn.Embedding(
self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
)
self.layers = nn.ModuleList(
[
TransformerLayer(
self.args.embed_dim,
self.args.ffn_embed_dim,
self.args.attention_heads,
add_bias_kv=(self.model_version != "ESM-1b"),
use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
)
for _ in range(self.args.layers)
]
)
self.contact_head = ContactPredictionHead(
self.args.layers * self.args.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
def _init_submodules_esm1b(self):
self._init_submodules_common()
self.embed_scale = 1
self.embed_positions = LearnedPositionalEmbedding(
self.args.max_positions, self.args.embed_dim, self.padding_idx
)
self.emb_layer_norm_before = (
ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
)
self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.args.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)
def _init_submodules_esm1(self):
self._init_submodules_common()
self.embed_scale = math.sqrt(self.args.embed_dim)
self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
self.embed_out_bias = None
if self.args.final_bias:
self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
if return_contacts:
need_head_weights = True
assert tokens.ndim == 2
padding_mask = tokens.eq(self.padding_idx) # B, T
x = self.embed_scale * self.embed_tokens(tokens)
if getattr(self.args, "token_dropout", False):
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
x = x + self.embed_positions(tokens)
if self.model_version == "ESM-1b":
if self.emb_layer_norm_before:
x = self.emb_layer_norm_before(x)
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
attn_weights = []
# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)
if not padding_mask.any():
padding_mask = None
for layer_idx, layer in enumerate(self.layers):
x, attn = layer(
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
)
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
if need_head_weights:
# (H, B, T, T) => (B, H, T, T)
attn_weights.append(attn.transpose(1, 0))
if self.model_version == "ESM-1b":
x = self.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)
else:
x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# attentions: B x L x H x T x T
attentions = torch.stack(attn_weights, 1)
if self.model_version == "ESM-1":
# ESM-1 models have an additional null-token for attention, which we remove
attentions = attentions[..., :-1]
if padding_mask is not None:
attention_mask = 1 - padding_mask.type_as(attentions)
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result
def predict_contacts(self, tokens):
return self(tokens, return_contacts=True)["contacts"]
@property
def num_layers(self):
return self.args.layers