# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
import torch.nn.functional as F

from .gvp_transformer_encoder import GVPTransformerEncoder
from .transformer_decoder import TransformerDecoder
from .util import rotate, CoordBatchConverter

class GVPTransformerModel(nn.Module):
    """
    GVP-Transformer inverse folding model.

    Architecture: Geometric GVP-GNN as initial layers, followed by
    sequence-to-sequence Transformer encoder and decoder.
    """

    def __init__(self, args, alphabet):
        super().__init__()
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder
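
    # Constructor sketch (an assumption, not the canonical config): the code above
    # only requires that `args` expose `encoder_embed_dim` and `decoder_embed_dim`;
    # GVPTransformerEncoder and TransformerDecoder read any further hyperparameters
    # from the same namespace. The values below are placeholders.
    #
    #   from argparse import Namespace
    #   args = Namespace(encoder_embed_dim=512, decoder_embed_dim=512, ...)
    #   model = GVPTransformerModel(args, alphabet)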

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
        return encoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )
        return decoder

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.padding_idx
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb
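
    # For illustration only (hypothetical sizes, not taken from any checkpoint):
    # for a 35-token alphabet with padding_idx=1 and embed_dim=512, the method
    # above is equivalent to
    #   emb = nn.Embedding(35, 512, padding_idx=1)
    #   nn.init.normal_(emb.weight, mean=0, std=512 ** -0.5)  # std = embed_dim ** -0.5
    #   nn.init.constant_(emb.weight[1], 0)                   # zero the padding row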

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        encoder_out = self.encoder(coords, padding_mask, confidence,
            return_all_hiddens=return_all_hiddens)
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra
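
    # Hedged usage sketch. The shapes below are assumptions inferred from how
    # sample() builds its batch via CoordBatchConverter, not guarantees of this API:
    #   coords:             (B, L, 3, 3) float  -- backbone atom coordinates per residue
    #   padding_mask:       (B, L)
    #   confidence:         (B, L) float
    #   prev_output_tokens: (B, L) long, shifted target tokens
    #
    #   logits, extra = model(coords, padding_mask, confidence, prev_output_tokens)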

    def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
        """
        Samples sequences based on multinomial sampling (no beam search).

        Args:
            coords: L x 3 x 3 list representing one backbone
            partial_seq: Optional, partial sequence with mask tokens if part of
                the sequence is known
            temperature: sampling temperature, use low temperature for higher
                sequence recovery and high temperature for higher diversity
            confidence: optional length L list of confidence scores for coordinates
        """
        L = len(coords)

        # Convert to batch format
        batch_converter = CoordBatchConverter(self.decoder.dictionary)
        batch_coords, confidence, _, _, padding_mask = (
            batch_converter([(coords, confidence, None)], device=device)
        )

        # Start with prepend token
        mask_idx = self.decoder.dictionary.get_idx('<mask>')
        sampled_tokens = torch.full((1, 1 + L), mask_idx, dtype=int)
        sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
        if partial_seq is not None:
            for i, c in enumerate(partial_seq):
                sampled_tokens[0, i + 1] = self.decoder.dictionary.get_idx(c)

        # Save incremental states for faster sampling
        incremental_state = dict()

        # Run encoder only once
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)

        # Make sure all tensors are on the same device if a GPU is present
        if device:
            sampled_tokens = sampled_tokens.to(device)

        # Decode one token at a time
        for i in range(1, L + 1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],
                encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            if sampled_tokens[0, i] == mask_idx:
                sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
        sampled_seq = sampled_tokens[0, 1:]

        # Convert back to string via lookup
        return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq]), encoder_out
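

# Hedged usage sketch (comments only, not part of this module). With the upstream
# ESM package this class is normally instantiated through the pretrained loader;
# the PDB path and chain id below are placeholders:
#
#   import esm
#   model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
#   model = model.eval()
#   coords, native_seq = esm.inverse_folding.util.load_coords("example.pdb", "A")
#   # Note: unlike the upstream sample(), the version above also returns the
#   # encoder output alongside the sampled sequence string.
#   sampled_seq, encoder_out = model.sample(coords, temperature=1.0)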