# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math | |
from typing import Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .multihead_attention import MultiheadAttention # noqa | |
from .axial_attention import ColumnSelfAttention, RowSelfAttention | |
def gelu(x):
    """Exact (erf-based) GELU activation, ``x * Phi(x)``.

    ``Phi`` is the standard normal CDF, evaluated here as
    ``0.5 * (1 + erf(x / sqrt(2)))``.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * (1.0 + erf_term)
def symmetrize(x):
    """Return ``x + x^T`` over the last two dimensions (used for contact prediction)."""
    transposed = x.transpose(-2, -1)
    return x + transposed
def apc(x):
    """Average product correction (APC), used for contact prediction.

    Subtracts ``rowsum * colsum / total`` from every entry, where the sums are
    taken over the last two dimensions and broadcast back to ``x``'s shape.
    """
    row_sums = x.sum(-1, keepdim=True)
    col_sums = x.sum(-2, keepdim=True)
    total = x.sum((-1, -2), keepdim=True)
    correction = row_sums * col_sums
    correction.div_(total)  # in-place to avoid another full-size temporary
    return x - correction
class ESM1LayerNorm(nn.Module):
    """LayerNorm over the trailing ``hidden_size`` dimensions, TF style (eps inside the sqrt)."""

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt).

        Args:
            hidden_size: int, or iterable of ints, giving the trailing
                dimensions to normalize over.
            eps: constant added to the variance before the square root.
            affine: when truthy, learn an elementwise scale ``weight``
                (initialized to 1) and shift ``bias`` (initialized to 0);
                otherwise both are ``None``.
        """
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if not self.affine:
            self.weight, self.bias = None, None
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # (-1, -2, ..., -len(hidden_size)): the trailing dims being normalized.
        norm_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))
        centered = x - x.mean(norm_dims, keepdim=True)
        var = centered.pow(2).mean(norm_dims, keepdim=True)  # biased variance
        out = centered / torch.sqrt(var + self.eps)
        if self.affine:
            out = (self.weight * out) + self.bias
        return out
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        """apex FusedLayerNorm that runs CUDA inputs under their own device context."""

        def forward(self, x):
            if x.is_cuda:
                # Make x's device current so the fused kernel launches on it.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)

except ImportError:
    # apex is unavailable: fall back to the stock PyTorch LayerNorm.
    from torch.nn import LayerNorm as ESM1bLayerNorm
class TransformerLayer(nn.Module):
    """Transformer layer block.

    Pre-LayerNorm layout: self-attention, then a GELU feed-forward network.
    Each sub-layer normalizes its input, and its output is added back onto
    the un-normalized input (residual connection).
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        """Create the attention, normalization, and feed-forward submodules.

        Submodules are registered in a fixed order (attention, attention norm,
        fc1, fc2, final norm); this order determines state_dict key order and
        the order in which random initialization happens.
        """
        norm_cls = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = norm_cls(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = norm_cls(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Apply the block and return ``(output, attention_weights)``.

        Attention weights are always requested from ``self_attn``;
        ``need_head_weights`` is forwarded to it unchanged.
        """
        normed = self.self_attn_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        ffn_out = self.fc2(gelu(self.fc1(self.final_layer_norm(x))))
        x = x + ffn_out
        return x, attn
class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block.

    Applies, in order: row-wise self-attention, column-wise self-attention,
    and a feed-forward network. Each sub-layer is wrapped in a
    ``NormalizedResidualBlock`` (LayerNorm -> sub-layer -> dropout ->
    residual add).
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout  # read by build_residual

        # NOTE(review): `attention_dropout` is accepted but never used; both
        # attention modules receive `dropout`. Left unchanged for parity.
        attn_options = dict(dropout=dropout, max_tokens_per_msa=max_tokens_per_msa)
        row_attn = RowSelfAttention(embedding_dim, num_attention_heads, **attn_options)
        col_attn = ColumnSelfAttention(embedding_dim, num_attention_heads, **attn_options)
        ffn = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        # All three sub-layers are constructed before wrapping, so parameter
        # initialization happens in the same order as before.
        self.row_self_attention = self.build_residual(row_attn)
        self.column_self_attention = self.build_residual(col_attn)
        self.feed_forward_layer = self.build_residual(ffn)

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a pre-norm residual block using this layer's dropout."""
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """Run row attention, column attention, then the feed-forward network.

        LayerNorm is applied *before* each sub-layer (pre-norm), inside its
        residual wrapper.

        Returns:
            ``(x, column_attn, row_attn)`` when ``need_head_weights`` is True,
            otherwise just ``x``.
        """
        mask_kwargs = dict(
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, row_attn = self.row_self_attention(x, **mask_kwargs)
        x, column_attn = self.column_self_attention(x, **mask_kwargs)
        x = self.feed_forward_layer(x)
        return (x, column_attn, row_attn) if need_head_weights else x
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    With an integer ``padding_idx``, position ids are derived from the input
    tokens. Non-padding tokens get consecutive ids starting at
    ``padding_idx + 1``, and padding tokens map to ``padding_idx`` (the
    embedding's padding row). The table therefore has
    ``num_embeddings + padding_idx + 1`` rows.

    With ``padding_idx=None``, every token is treated as a real token and the
    position ids are simply ``0 .. seq_len - 1``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen].

        Raises:
            ValueError: if ``seqlen`` exceeds ``max_positions``.
        """
        seq_len = input.size(1)
        if seq_len > self.max_positions:
            raise ValueError(
                f"Sequence length {seq_len} above maximum "
                f"sequence length of {self.max_positions}"
            )
        if self.padding_idx is None:
            # No padding token to skip: plain 0-based positions per row.
            positions = torch.arange(seq_len, device=input.device).expand_as(input)
        else:
            mask = input.ne(self.padding_idx).int()
            # Running count of non-pad tokens, zeroed at pads, shifted past padding_idx.
            positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional embeddings.

    The table is built lazily on the first forward and rebuilt whenever a
    longer sequence is seen. Row ``padding_idx`` is all zeros. A non-padding
    token in column ``j`` uses row ``padding_idx + 1 + j``, and padding tokens
    use row ``padding_idx``.
    """

    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        # `learned` is accepted for interface compatibility but is not used.
        # This 1-element buffer follows the module through .to()/.half(); the
        # cached table is cast to its tensor type on every forward.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        """Look up embeddings for a ``(bsz, seq_len)`` token tensor; returns ``(bsz, seq_len, embed_dim)``."""
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)
        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        """Map each token to its table row: ``padding_idx + 1 + column`` or ``padding_idx`` for pads."""
        mask = x.ne(self.padding_idx)
        positions = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        """Build a ``(num_embeddings, embed_dim)`` table of ``[sin | cos]`` features.

        Frequencies are geometrically spaced from 1 down to 1/10000. An odd
        ``embed_dim`` gets one trailing zero column, and row ``padding_idx``
        (if not None) is zeroed.
        """
        half_dim = self.embed_dim // 2
        # Guard the denominator: embed_dim in {2, 3} gives half_dim == 1,
        # which previously raised ZeroDivisionError.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        # cat along dim 1 already yields (num_embeddings, 2 * half_dim); no reshape
        # needed (a .view(num_embeddings, -1) here failed for embed_dim == 1).
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Computes dense -> GELU -> LayerNorm, then projects onto the output space
    with the caller-supplied ``weight`` matrix plus a learned bias.
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        # Projection matrix provided by the caller (presumably tied to the
        # token embedding table; confirm at the call site).
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        logits = F.linear(hidden, self.weight)
        return logits + self.bias
class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        """
        Args:
            in_features: number of stacked attention maps per example (this
                must equal ``layers * heads`` of the input to ``forward``);
                input size of the regression layer.
            prepend_bos: strip the first row/column of every attention map.
            append_eos: zero attention to/from tokens equal to ``eos_idx`` and
                strip the last row/column of every attention map.
            bias: whether the regression layer has a bias term.
            eos_idx: eos token id; required when ``append_eos`` is True.

        Raises:
            ValueError: if ``append_eos`` is True and ``eos_idx`` is None.
        """
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return contact probabilities of shape ``(batch, seqlen, seqlen)``.

        ``attentions`` must be 5-D ``(batch, layers, heads, T, T)``, and
        ``tokens`` 2-D ``(batch, T)``. ``seqlen`` is ``T`` minus any stripped
        bos/eos rows.
        """
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        # Match the regression weight's device AND dtype. Moving only the device
        # makes F.linear fail on a dtype mismatch (e.g. float32 maps with a
        # half-precision head).
        attentions = attentions.to(self.regression.weight)
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
class NormalizedResidualBlock(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(layer(layer_norm(x), ...))``.

    If the wrapped layer returns a tuple, its first element is the main output
    and the remaining elements are returned unchanged after the residual
    output, as a tuple.
    """

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.layer = layer
        self.dropout_module = nn.Dropout(dropout)
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        outputs = self.layer(self.layer_norm(x), *args, **kwargs)
        if isinstance(outputs, tuple):
            main, *extras = outputs
        else:
            main, extras = outputs, None
        y = x + self.dropout_module(main)
        return y if extras is None else (y,) + tuple(extras)
class FeedForwardNetwork(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> GELU -> dropout -> Linear."""

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        # Stored for callers; not read by this class's forward.
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        return self.fc2(hidden)