wenkai's picture
Upload 31 files
3f0529e verified
raw
history blame
14.8 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Portions of this file were adapted from the open source code for the following
# two papers:
#
# Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative
# models for graph-based protein design. Advances in Neural Information
# Processing Systems, 32.
#
# Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
# Learning from Protein Structure with Geometric Vector Perceptrons. In
# International Conference on Learning Representations.
#
# MIT License
#
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# ================================================================
# The below license applies to the portions of the code (parts of
# src/datasets.py and src/models.py) adapted from Ingraham, et al.
# ================================================================
#
# MIT License
#
# Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .gvp_utils import flatten_graph
from .gvp_modules import GVP, LayerNorm
from .util import normalize, norm, nan_to_num, rbf
class GVPInputFeaturizer(nn.Module):
    """Stateless helpers that turn backbone coordinates into GVP graph features.

    Every method is a ``@staticmethod``; the class holds no parameters of its
    own and is subclassed by ``GVPGraphEmbedding``.

    NOTE(review): the methods index ``coords[:, :, 0/1/2]`` as N, CA, C (see
    ``_sidechains`` and ``get_node_features``), i.e. coords is assumed to be
    (batch, length, atoms, 3) — TODO confirm against callers.
    """

    @staticmethod
    def get_node_features(coords, coord_mask, with_coord_mask=True):
        """Compute per-residue scalar and vector node features.

        Args:
            coords: backbone coordinates indexed as ``coords[:, :, atom]``,
                with atom 1 used as CA.
            coord_mask: per-residue mask; appended as a float scalar channel
                when ``with_coord_mask`` is True.
            with_coord_mask: whether to append the mask channel.

        Returns:
            ``(node_scalar_features, node_vector_features)``. The scalars are
            cos/sin of three dihedrals per residue (6 values), plus the
            optional mask channel. The vectors stack the forward and backward
            CA orientations with the side-chain direction (3 vectors per
            residue).
        """
        # scalar features
        node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
        if with_coord_mask:
            node_scalar_features = torch.cat([
                node_scalar_features,
                coord_mask.float().unsqueeze(-1)
            ], dim=-1)
        # vector features
        X_ca = coords[:, :, 1]
        orientations = GVPInputFeaturizer._orientations(X_ca)
        sidechains = GVPInputFeaturizer._sidechains(coords)
        node_vector_features = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
        return node_scalar_features, node_vector_features

    @staticmethod
    def _orientations(X):
        """Normalized vectors from each position to its sequence neighbours.

        Args:
            X: positions indexed ``X[:, i]`` along the sequence (CA atoms when
                called from ``get_node_features``).

        Returns:
            Tensor with a new second-to-last axis of size 2. Index 0 is the
            direction to the next position (zero for the last position).
            Index 1 is the direction to the previous position (zero for the
            first position).
        """
        forward = normalize(X[:, 1:] - X[:, :-1])
        backward = normalize(X[:, :-1] - X[:, 1:])
        # Zero-pad along the sequence axis so both have one entry per position:
        # forward has no successor at the end, backward no predecessor at the start.
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    @staticmethod
    def _sidechains(X):
        """Idealized side-chain direction at each residue from its N, CA, C atoms.

        The normalized CA->C and CA->N vectors give a bisector and a normal
        (their cross product). These are combined with fixed weights
        ``sqrt(1/3)`` and ``sqrt(2/3)``.
        NOTE(review): presumably an imputed C-beta direction — confirm.
        """
        n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
        c, n = normalize(c - origin), normalize(n - origin)
        bisector = normalize(c + n)
        perp = normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Backbone dihedrals lifted to the unit circle.

        The first three atoms of every residue are flattened into one chain of
        ``3 * length`` points. A signed dihedral is computed for every window
        of four consecutive points, padded back to three angles per residue,
        and returned as ``[cos(angles), sin(angles)]`` (6 values per residue).

        Args:
            X: coordinates with at least three atoms on axis 2.
            eps: the cosine is clamped to ``(-1 + eps, 1 - eps)`` before
                ``acos``, whose gradient is unbounded at +/-1.
        """
        X = torch.flatten(X[:, :, :3], 1, 2)
        bsz = X.shape[0]
        dX = X[:, 1:] - X[:, :-1]
        U = normalize(dX, dim=-1)
        u_2 = U[:, :-2]
        u_1 = U[:, 1:-1]
        u_0 = U[:, 2:]
        # Backbone normals
        n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [bsz, -1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
        return D_features

    @staticmethod
    def _positional_embeddings(edge_index,
                               num_embeddings=None,
                               num_positional_embeddings=16,
                               period_range=(2, 1000)):
        """Sinusoidal embeddings of the index offset along each edge.

        Args:
            edge_index: integer tensor whose first axis holds (source, dest)
                indices. The embedded offset is ``edge_index[0] - edge_index[1]``.
            num_embeddings: output width. Falls back to
                ``num_positional_embeddings`` when None (or 0).
            num_positional_embeddings: default output width.
            period_range: unused; kept only so existing call signatures stay
                valid. It is a tuple so the shared default cannot be mutated.

        Returns:
            ``[cos(d * f), sin(d * f)]`` concatenated on the last axis, with
            ``f_j = exp(-j * ln(10000) / num_embeddings)`` for ``j = 0, 2, 4, ...``.
            That gives ``2 * ceil(num_embeddings / 2)`` features per edge,
            so an odd width yields one extra feature.
        """
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or num_positional_embeddings
        d = edge_index[0] - edge_index[1]
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32,
                         device=edge_index.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    @staticmethod
    def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
        """Pairwise Euclidean distances and nearest-neighbour selection.

        Distances are offset so that sorting ascending prefers:
        1. pairs where both residues have coordinates (true distance);
        2. pairs missing coordinates, at ``+1e8 + 1e6 * |i - j|``;
        3. pairs touching padding, at an extra ``+1e10``.

        Args:
            X: positions; assumed (batch, length, 3) — TODO confirm.
            coord_mask: True where a residue has coordinates.
            padding_mask: True at padding positions.
            top_k_neighbors: neighbours per residue (self included). -1 keeps
                all positions in sequence order.
            eps: unused.

        Returns:
            ``(D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors)``.
            The two masks flag neighbours whose adjusted distance is below
            5e7 (both ends have coordinates) and below 5e9 (not padding).

        NOTE(review): for coordinate-less pairs more than ~4900 positions
        apart, ``1e8 + 1e6 * |i - j|`` exceeds 5e9. Such pairs would then be
        reported as padding.
        """
        bsz, maxlen = X.size(0), X.size(1)
        coord_mask_2D = torch.unsqueeze(coord_mask, 1) * torch.unsqueeze(coord_mask, 2)
        residue_mask = ~padding_mask
        residue_mask_2D = torch.unsqueeze(residue_mask, 1) * torch.unsqueeze(residue_mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = coord_mask_2D * norm(dX, dim=-1)
        # sorting preference: first those with coords, then among the residues that
        # exist but are masked use distance in sequence as tie breaker, and then the
        # residues that came from padding are last
        seqpos = torch.arange(maxlen, device=X.device)
        Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
        D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq*1e6) + (
            ~residue_mask_2D) * (1e10)
        if top_k_neighbors == -1:
            D_neighbors = D_adjust
            E_idx = seqpos.repeat(
                *D_neighbors.shape[:-1], 1)
        else:
            # Identify k nearest neighbors (including self)
            k = min(top_k_neighbors, X.size(1))
            D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
        coord_mask_neighbors = (D_neighbors < 5e7)
        residue_mask_neighbors = (D_neighbors < 5e9)
        return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors
class Normalize(nn.Module):
    """Layer normalization with a learnable per-feature gain and bias.

    ``epsilon`` enters twice: inside the square root of the variance and
    again when dividing by the resulting standard deviation. The variance
    uses ``Tensor.var``'s default (unbiased) estimate.
    """

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` and apply the learned affine map.

        For ``dim != -1``, gain and bias are reshaped to broadcast along
        ``dim``, which must have size ``features``.
        """
        centered = x - x.mean(dim, keepdim=True)
        std = (x.var(dim, keepdim=True) + self.epsilon).sqrt()
        scale, shift = self.gain, self.bias
        if dim != -1:
            # keepdim statistics keep x's rank, so a rank-matched shape broadcasts.
            broadcast_shape = [1] * x.dim()
            broadcast_shape[dim] = self.gain.size(0)
            scale = scale.view(broadcast_shape)
            shift = shift.view(broadcast_shape)
        return scale * centered / (std + self.epsilon) + shift
class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angles (phi, psi, omega) per residue.

    Each angle is lifted to the unit circle as (cos, sin), linearly projected
    to ``node_embed_dim``, then normalized with ``Normalize``.
    """

    def __init__(self, node_embed_dim):
        super().__init__()
        # Three dihedrals per residue, each contributing a cos and a sin.
        angle_feature_dim = 3 * 2
        self.node_embedding = nn.Linear(angle_feature_dim, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Return normalized dihedral embeddings for backbone coordinates ``X``."""
        angle_features = self._dihedrals(X)
        return self.norm_nodes(self.node_embedding(angle_features))

    @staticmethod
    def _dihedrals(X, eps=1e-7, return_angles=False):
        """Signed backbone dihedrals for each residue.

        The first three atoms of every residue (N, CA, C per the original
        code's comment) are interleaved into one chain. The signed torsion
        of every window of four consecutive atoms is then computed.

        Args:
            X: coordinates indexed ``X[batch, residue, atom, xyz]``.
            eps: the cosine is clamped to ``(-1 + eps, 1 - eps)`` before ``acos``.
            return_angles: if True, return the raw ``(phi, psi, omega)`` tensors.

        Returns:
            ``(phi, psi, omega)`` each of shape (batch, residues) when
            ``return_angles``; otherwise (batch, residues, 6) holding
            ``[cos(phi, psi, omega), sin(phi, psi, omega)]``. The undefined
            phi of the first residue and psi/omega of the last are 0.
        """
        n_batch, n_res = X.shape[0], X.shape[1]
        chain = X[:, :, :3, :].reshape(n_batch, 3 * n_res, 3)
        bonds = F.normalize(chain[:, 1:, :] - chain[:, :-1, :], dim=-1)
        first, middle, last = bonds[:, :-2, :], bonds[:, 1:-1, :], bonds[:, 2:, :]
        # Normals of the two planes that share the middle bond.
        plane_a = F.normalize(torch.cross(first, middle, dim=-1), dim=-1)
        plane_b = F.normalize(torch.cross(middle, last, dim=-1), dim=-1)
        cos_angle = torch.clamp((plane_a * plane_b).sum(-1), -1 + eps, 1 - eps)
        orientation = torch.sign((first * plane_b).sum(-1))
        angles = orientation * torch.acos(cos_angle)
        # One leading and two trailing zeros restore 3 angles per residue.
        angles = F.pad(angles, (1, 2), 'constant', 0)
        angles = angles.view((angles.size(0), angles.size(1) // 3, 3))
        if return_angles:
            phi, psi, omega = angles.unbind(-1)
            return phi, psi, omega
        return torch.cat((angles.cos(), angles.sin()), 2)
class GVPGraphEmbedding(GVPInputFeaturizer):
    """Embed a padded batch of backbone structures as one flattened GVP graph.

    Node and edge features come from the inherited ``GVPInputFeaturizer``
    helpers and are computed without gradient tracking. They are projected by
    GVP + LayerNorm stacks, and an RBF expansion of the per-residue
    ``confidence`` is added to the scalar node channel. The batch is then
    flattened with ``flatten_graph``.
    """

    def __init__(self, args):
        # ``args`` must provide the attributes read below: top_k_neighbors,
        # node_hidden_dim_scalar, node_hidden_dim_vector,
        # edge_hidden_dim_scalar and edge_hidden_dim_vector.
        super().__init__()
        self.top_k_neighbors = args.top_k_neighbors
        self.num_positional_embeddings = 16
        # When True, get_edge_features sets to -1 the indices of edges where
        # either endpoint lacks coordinates.
        self.remove_edges_without_coords = True
        # Node input: 7 scalars (6 dihedral cos/sin + 1 coord-mask channel) and
        # 3 vectors (forward/backward orientation + side-chain direction),
        # matching get_node_features(..., with_coord_mask=True).
        node_input_dim = (7, 3)
        # Edge input: 34 scalars and 1 vector, matching get_edge_features.
        # NOTE(review): 34 presumably = 16 distance RBFs (rbf's default bin
        # count — TODO confirm) + 16 positional embeddings + 2 mask flags.
        edge_input_dim = (34, 1)
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)
        self.embed_node = nn.Sequential(
            GVP(node_input_dim, node_hidden_dim, activations=(None, None)),
            LayerNorm(node_hidden_dim, eps=1e-4)
        )
        self.embed_edge = nn.Sequential(
            GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)),
            LayerNorm(edge_hidden_dim, eps=1e-4)
        )
        # NOTE(review): 16 presumably matches the number of bins produced by
        # rbf(confidence, 0., 1.) in forward — confirm against util.rbf.
        self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Featurize, embed and flatten a padded batch.

        Args:
            coords: backbone coordinates. Assumed (batch, length, atoms, 3)
                with N, CA, C as atoms 0-2 — TODO confirm.
            coord_mask: boolean mask, True where a residue has coordinates.
            padding_mask: boolean mask, True at padding positions.
            confidence: per-residue values, RBF-expanded over [0, 1].

        Returns:
            ``(node_embeddings, edge_embeddings, edge_index)`` as returned by
            ``flatten_graph``.
        """
        # Raw geometric features need no gradients; the learnable embeddings
        # below run outside the no_grad context.
        with torch.no_grad():
            node_features = self.get_node_features(coords, coord_mask)
            edge_features, edge_index = self.get_edge_features(
                coords, coord_mask, padding_mask)
        node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features)
        edge_embeddings = self.embed_edge(edge_features)
        rbf_rep = rbf(confidence, 0., 1.)
        # Confidence only conditions the scalar channel; vectors pass through.
        node_embeddings = (
            node_embeddings_scalar + self.embed_confidence(rbf_rep),
            node_embeddings_vector
        )
        node_embeddings, edge_embeddings, edge_index = flatten_graph(
            node_embeddings, edge_embeddings, edge_index)
        return node_embeddings, edge_embeddings, edge_index

    def get_edge_features(self, coords, coord_mask, padding_mask):
        """Build nearest-neighbour edge features over CA atoms.

        Args:
            coords: backbone coordinates; atom index 1 is used as CA.
            coord_mask: True where a residue has coordinates.
            padding_mask: True at padding positions.

        Returns:
            ``((edge_s, edge_v), edge_index)``, where:
            - ``edge_s`` concatenates distance RBFs, relative positional
              embeddings and two flags (source / destination lacks
              coordinates);
            - ``edge_v`` holds one normalized ``X_src - X_dest`` vector per
              edge;
            - ``edge_index`` has shape (batch, 2, length * k), with row 0 the
              source residue and row 1 its neighbour.
            Edges touching padding, and (when ``remove_edges_without_coords``)
            edges where either endpoint lacks coordinates, have both indices
            set to -1.
        """
        X_ca = coords[:, :, 1]
        # Get distances to the top k neighbors
        E_dist, E_idx, E_coord_mask, E_residue_mask = GVPInputFeaturizer._dist(
            X_ca, coord_mask, padding_mask, self.top_k_neighbors)
        # Flatten the graph to be batch size 1 for torch_geometric package
        dest = E_idx
        B, L, k = E_idx.shape[:3]
        src = torch.arange(L, device=E_idx.device).view([1, L, 1]).expand(B, L, k)
        # After flattening, [2, B, E]
        edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
        # After flattening, [B, E]
        E_dist = E_dist.flatten(1, 2)
        E_coord_mask = E_coord_mask.flatten(1, 2).unsqueeze(-1)
        E_residue_mask = E_residue_mask.flatten(1, 2)
        # Calculate relative positional embeddings and distance RBF
        pos_embeddings = GVPInputFeaturizer._positional_embeddings(
            edge_index,
            num_positional_embeddings=self.num_positional_embeddings,
        )
        D_rbf = rbf(E_dist, 0., 20.)
        # Calculate relative orientation
        X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
        X_dest = torch.gather(
            X_ca,
            1,
            edge_index[1, :, :].unsqueeze(-1).expand([B, L*k, 3])
        )
        coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
        coord_mask_dest = torch.gather(
            coord_mask,
            1,
            edge_index[1, :, :].expand([B, L*k])
        )
        E_vectors = X_src - X_dest
        # For the ones without coordinates, substitute in the average vector
        # (mean over coordinate-bearing edges of each batch element; this is
        # 0/0 = NaN when an element has none, presumably cleaned by nan_to_num below).
        E_vector_mean = torch.sum(E_vectors * E_coord_mask, dim=1,
            keepdims=True) / torch.sum(E_coord_mask, dim=1, keepdims=True)
        E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)
        # Normalize and remove nans
        edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
        edge_v = normalize(E_vectors).unsqueeze(-2)
        edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))
        # Also add indications of whether the coordinates are present
        edge_s = torch.cat([
            edge_s,
            (~coord_mask_src).float().unsqueeze(-1),
            (~coord_mask_dest).float().unsqueeze(-1),
        ], dim=-1)
        # ORDER MATTERS: edge_index is overwritten in place only after every
        # use of its indices above (positional embeddings, gathers); -1 is
        # not a valid gather index.
        edge_index[:, ~E_residue_mask] = -1
        if self.remove_edges_without_coords:
            edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
        # [2, B, E] -> [B, 2, E]
        return (edge_s, edge_v), edge_index.transpose(0, 1)