# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import biotite.structure
import numpy as np
import torch
from typing import Sequence, Tuple, List

from esm.inverse_folding.util import (
    load_structure,
    extract_coords_from_structure,
    load_coords,
    get_sequence_loss,
    get_encoder_output,
)


def extract_coords_from_complex(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: biotite AtomArray
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    coords = {}
    seqs = {}
    all_chains = biotite.structure.get_chains(structure)
    for chain_id in all_chains:
        chain = structure[structure.chain_id == chain_id]
        coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
    return coords, seqs


def load_complex_coords(fpath, chains):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chains: the chain ids (the order matters for autoregressive model)
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    structure = load_structure(fpath, chains)
    return extract_coords_from_complex(structure)


def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        An L_total x 3 x 3 array for N, CA, C coordinates: a concatenation of
        the chains with NaN padding in between, with the target chain placed
        first
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    coords_list = [coords[target_chain_id]]
    for chain_id in coords:
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    coords_concatenated = np.concatenate(coords_list, axis=0)
    return coords_concatenated
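
# Illustrative sketch of the shape contract of _concatenate_coords, assuming
# two hypothetical chains 'A' (length 75) and 'B' (length 50):
#
#   coords = {
#       'A': np.zeros((75, 3, 3), dtype=np.float32),
#       'B': np.zeros((50, 3, 3), dtype=np.float32),
#   }
#   cat = _concatenate_coords(coords, target_chain_id='A')
#   # Target chain first, then `padding_length` rows of NaN, then chain B:
#   assert cat.shape == (75 + 10 + 50, 3, 3)
#   assert np.isnan(cat[75:85]).all()
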
def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10):
    """
    Samples sequence for one chain in a complex.

    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        temperature: Sampling temperature
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    all_coords = _concatenate_coords(coords, target_chain_id, padding_length)
    device = next(model.parameters()).device

    # Supply padding tokens for other chains to avoid unused sampling for speed
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for i in range(target_chain_len):
        padding_pattern[i] = '<mask>'
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
            temperature=temperature, device=device)
    sampled = sampled[:target_chain_len]
    return sampled


def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
        target_seq, padding_length=10):
    """
    Scores sequence for one chain in a complex.

    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to score sequences for
        target_seq: Target sequence for the target chain for scoring
        padding_length: padding length in between chains
    Returns:
        Tuple (ll_fullseq, ll_withcoord)
        - ll_fullseq: Average log-likelihood over the full target chain
        - ll_withcoord: Average log-likelihood over the target chain, excluding
          residues without coordinates
    """
    all_coords = _concatenate_coords(coords, target_chain_id, padding_length)
    loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
            target_seq)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(
            ~target_padding_mask)

    # Also calculate average when excluding masked portions
    coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord


def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
    """
    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to extract encoder outputs for
    Returns:
        Encoder output for the target chain only (the first target_chain_len
        positions of the concatenated representation)
    """
    all_coords = _concatenate_coords(coords, target_chain_id)
    all_rep = get_encoder_output(model, alphabet, all_coords)
    target_chain_len = coords[target_chain_id].shape[0]
    return all_rep[:target_chain_len]
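

# Minimal usage sketch, assuming the ESM-IF1 weights are available through
# esm.pretrained and that 'complex.pdb' with chains 'A' and 'B' is a
# placeholder structure file; adapt the path and chain ids to your data.
if __name__ == "__main__":
    import esm

    model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    model = model.eval()

    # Load backbone coordinates and native sequences for both chains.
    coords, native_seqs = load_complex_coords("complex.pdb", ["A", "B"])

    # Sample a new sequence for chain A conditioned on the full complex.
    sampled_seq = sample_sequence_in_complex(
        model, coords, target_chain_id="A", temperature=1.0
    )
    print("Sampled chain A sequence:", sampled_seq)

    # Score the native chain A sequence in the context of the complex.
    ll_fullseq, ll_withcoord = score_sequence_in_complex(
        model, alphabet, coords, target_chain_id="A",
        target_seq=native_seqs["A"],
    )
    print(f"Average log-likelihood (full chain): {ll_fullseq:.3f}")
    print(f"Average log-likelihood (with coordinates): {ll_withcoord:.3f}")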