# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
import torch.nn.functional as F

from .gvp_transformer_encoder import GVPTransformerEncoder
from .transformer_decoder import TransformerDecoder
from .util import rotate, CoordBatchConverter


class GVPTransformerModel(nn.Module):
"""
GVP-Transformer inverse folding model.
Architecture: Geometric GVP-GNN as initial layers, followed by
sequence-to-sequence Transformer encoder and decoder.
"""
def __init__(self, args, alphabet):
super().__init__()
encoder_embed_tokens = self.build_embedding(
args, alphabet, args.encoder_embed_dim,
)
decoder_embed_tokens = self.build_embedding(
args, alphabet, args.decoder_embed_dim,
)
encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
self.args = args
self.encoder = encoder
self.decoder = decoder
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
return encoder
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = TransformerDecoder(
args,
tgt_dict,
embed_tokens,
)
return decoder
@classmethod
def build_embedding(cls, args, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.padding_idx
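        # Transformer-style scaled initialization: weights ~ N(0, embed_dim^-0.5),
        # with the padding embedding zeroed out below.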
emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
nn.init.constant_(emb.weight[padding_idx], 0)
return emb
def forward(
self,
coords,
padding_mask,
confidence,
prev_output_tokens,
return_all_hiddens: bool = False,
features_only: bool = False,
):
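        """
        Encode the input backbone coordinates with the GVP-Transformer encoder,
        then run the Transformer decoder conditioned on `prev_output_tokens`.

        Returns the decoder output `(logits, extra)`, where `logits` are
        unnormalized scores over the output alphabet.
        """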
encoder_out = self.encoder(coords, padding_mask, confidence,
return_all_hiddens=return_all_hiddens)
logits, extra = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
return_all_hiddens=return_all_hiddens,
)
return logits, extra
def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
"""
Samples sequences based on multinomial sampling (no beam search).
Args:
coords: L x 3 x 3 list representing one backbone
partial_seq: Optional, partial sequence with mask tokens if part of
the sequence is known
temperature: sampling temperature, use low temperature for higher
sequence recovery and high temperature for higher diversity
            confidence: optional length L list of confidence scores for coordinates
        Returns:
            A tuple (sampled_seq, encoder_out): the sampled sequence as a string
            and the encoder output produced for the input coordinates.
        """
L = len(coords)
# Convert to batch format
batch_converter = CoordBatchConverter(self.decoder.dictionary)
batch_coords, confidence, _, _, padding_mask = (
batch_converter([(coords, confidence, None)], device=device)
)
# Start with prepend token
mask_idx = self.decoder.dictionary.get_idx('<mask>')
sampled_tokens = torch.full((1, 1+L), mask_idx, dtype=int)
sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
if partial_seq is not None:
for i, c in enumerate(partial_seq):
sampled_tokens[0, i+1] = self.decoder.dictionary.get_idx(c)
# Save incremental states for faster sampling
incremental_state = dict()
# Run encoder only once
encoder_out = self.encoder(batch_coords, padding_mask, confidence)
# Make sure all tensors are on the same device if a GPU is present
if device:
sampled_tokens = sampled_tokens.to(device)
# Decode one token at a time
for i in range(1, L+1):
logits, _ = self.decoder(
sampled_tokens[:, :i],
encoder_out,
incremental_state=incremental_state,
)
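            # Take the single batch element and put the vocabulary on the last
            # axis, then apply temperature scaling before the softmax.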
logits = logits[0].transpose(0, 1)
logits /= temperature
probs = F.softmax(logits, dim=-1)
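            # Only sample at positions that are still <mask>; residues fixed
            # via partial_seq are left unchanged.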
if sampled_tokens[0, i] == mask_idx:
sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
sampled_seq = sampled_tokens[0, 1:]
# Convert back to string via lookup
return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq]), encoder_out
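

# ---------------------------------------------------------------------------
# Illustrative usage sketch (comment only, not executed). It assumes this
# class is the one instantiated by the pretrained ESM-IF1 loader, which
# returns the model together with its alphabet:
#
#     import esm
#     model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
#     model = model.eval()
#     # coords: length-L list of backbone residues, each 3 atoms (N, CA, C) x 3D
#     seq, encoder_out = model.sample(coords, temperature=1.0)
# ---------------------------------------------------------------------------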