# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import typing as T
from contextlib import ExitStack
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from openfold.model.structure_module import StructureModule

from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock


@dataclass
class StructureModuleConfig:
    c_s: int = 384
    c_z: int = 128
    c_ipa: int = 16
    c_resnet: int = 128
    no_heads_ipa: int = 12
    no_qk_points: int = 4
    no_v_points: int = 8
    dropout_rate: float = 0.1
    no_blocks: int = 8
    no_transition_layers: int = 1
    no_resnet_blocks: int = 2
    no_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5


@dataclass
class FoldingTrunkConfig:
    _name: str = "FoldingTrunkConfig"
    num_blocks: int = 48
    sequence_state_dim: int = 1024
    pairwise_state_dim: int = 128
    sequence_head_width: int = 32
    pairwise_head_width: int = 32
    position_bins: int = 32
    dropout: float = 0
    layer_drop: float = 0
    cpu_grad_checkpoint: bool = False

    max_recycles: int = 4
    chunk_size: T.Optional[int] = None

    # A mutable dataclass default must go through default_factory
    # (a bare StructureModuleConfig() default raises on Python >= 3.11).
    structure_module: StructureModuleConfig = field(default_factory=StructureModuleConfig)


def get_axial_mask(mask):
    """
    Helper to convert a B x L mask of valid positions to the axial mask used
    in row/column attentions.

    Input:
      mask: B x L tensor of booleans

    Output:
      mask: (B*L) x L tensor of booleans
    """

    if mask is None:
        return None

    assert len(mask.shape) == 2
    batch_dim, seq_dim = mask.shape
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m


class RelativePosition(nn.Module):
    def __init__(self, bins, pairwise_state_dim):
        super().__init__()
        self.bins = bins

        # Note an additional offset is used so that the 0th position
        # is reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
          residue_index: B x L tensor of indices (dtype=torch.long)
          mask: B x L tensor of booleans

        Output:
          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """

        assert residue_index.dtype == torch.long
        if mask is not None:
            assert residue_index.shape == mask.shape

        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.
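        # Offsets now lie in [1, 2*bins + 1]; index 0 of the embedding table is
        # reserved below for pairs that involve a masked position.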
        if mask is not None:
            mask = mask[:, None, :] * mask[:, :, None]
            diff[~mask] = 0

        output = self.embedding(diff)
        return output


class FoldingTrunk(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.cfg = FoldingTrunkConfig(**kwargs)
        assert self.cfg.max_recycles > 0

        c_s = self.cfg.sequence_state_dim
        c_z = self.cfg.pairwise_state_dim

        assert c_s % self.cfg.sequence_head_width == 0
        assert c_z % self.cfg.pairwise_head_width == 0
        block = TriangularSelfAttentionBlock

        self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)

        self.blocks = nn.ModuleList(
            [
                block(
                    sequence_state_dim=c_s,
                    pairwise_state_dim=c_z,
                    sequence_head_width=self.cfg.sequence_head_width,
                    pairwise_head_width=self.cfg.pairwise_head_width,
                    dropout=self.cfg.dropout,
                )
                for i in range(self.cfg.num_blocks)
            ]
        )

        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        self.recycle_disto.weight[0].detach().zero_()

        self.structure_module = StructureModule(**self.cfg.structure_module)  # type: ignore
        self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
        self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)

        self.chunk_size = self.cfg.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L)
        # instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're
        # iterating over, where chunk_size is the size of each chunk, so a value of
        # 128 means the attention is computed 128 positions at a time.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
        """
        Inputs:
          seq_feats: B x L x C tensor of sequence features
          pair_feats: B x L x L x C tensor of pair features
          true_aa: B x L long tensor of amino-acid indices, consumed by the
            structure module
          residx: B x L long tensor giving the position in the sequence
          mask: B x L boolean tensor indicating valid residues
          no_recycles: optional override for the number of recycling passes

        Output:
          structure: dict of structure-module outputs, augmented with the
            final trunk states "s_s" and "s_z"
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.cfg.max_recycles
        else:
            assert no_recycles >= 0, "Number of recycles must not be negative."
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        assert no_recycles > 0
        for recycle_idx in range(no_recycles):
            # Gradients flow only through the last pass; earlier passes run
            # under no_grad (ExitStack() is a no-op context manager here).
            with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach())
                recycle_z = self.recycle_z_norm(recycle_z.detach())
                recycle_z += self.recycle_disto(recycle_bins.detach())

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
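                # positions[-1] holds the final structure-module layer's atom
                # coordinates; [:, :, :3] selects the backbone N, CA, C atoms.
                # 3.375 and 21.375 (Angstroms) are the first and last bin edges
                # for the self.recycle_bins = 15 recycling distance bins.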
                recycle_bins = FoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        assert isinstance(structure, dict)  # type: ignore
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        # Compare squared distances against squared boundaries to avoid a sqrt.
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        # Infer CB coordinates from the N, CA, C backbone frame.
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdim=True)
        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
        return bins
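
# A minimal smoke-test sketch for the two components above that depend only on
# torch (RelativePosition and FoldingTrunk.distogram). This block is not part
# of the original module: shapes and values are arbitrary, and running it still
# requires the openfold/esm imports at the top of this file to resolve.
if __name__ == "__main__":
    B, L = 2, 16

    # Relative-position embedding: sequential indices plus one masked position.
    residx = torch.arange(L).unsqueeze(0).expand(B, L)
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[:, -1] = False
    relpos = RelativePosition(bins=32, pairwise_state_dim=128)
    pairwise = relpos(residx, mask=mask)
    assert pairwise.shape == (B, L, L, 128)

    # Distogram: random [N, CA, C] backbone coordinates -> 15 recycling bins.
    coords = torch.randn(B, L, 3, 3) * 10
    bins = FoldingTrunk.distogram(coords, 3.375, 21.375, num_bins=15)
    assert bins.shape == (B, L, L)
    assert bins.max() < 15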