File size: 6,221 Bytes
0da959e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
import scipy.spatial as ss
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_sparse import coalesce

# Index -> element symbol; used as the allowable set for protein atom one-hot
# features in prot_df_to_graph. 'UNK' sits at the last index (10), which is the
# slot one_of_k_encoding_unk_indices uses for values outside the mapping.
atom_mapping = dict(enumerate([
    'H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK',
]))

# Index -> residue name; not referenced by the graph-building functions below.
residue_mapping = dict(enumerate([
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'UNK',
]))

# ``element`` value -> one-hot slot; the default allowable_atoms of
# mol_df_to_graph_for_qm. Keys look like atomic numbers (8 = O, 16 = S, 6 = C, ...)
# -- NOTE(review): confirm against the dataset loader.
ligand_atoms_mapping = {
    element: slot
    for slot, element in enumerate(
        [8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33, 26, 14, 34, 44, 12, 23, 77, 27, 52, 30, 4, 45]
    )
}


def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.

    :param item: Dataset item the structure belongs to. Only ``item['id']`` is read, to identify the structure in the error message.
    :type item: dict
    :param df: Protein structure in dataframe format, with ``x``, ``y``, ``z`` coordinate columns and a ``feat_col`` column.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance cutoff (in Angstroms) to define an edge between two atoms.
    :type edge_dist_cutoff: float
    :param feat_col: Column holding numeric atom type codes. A value ``e`` is one-hot encoded at slot ``e - 1`` of ``atom_mapping``;
        values whose ``e - 1`` is not a key of ``atom_mapping`` go to its last ('UNK') slot. Defaults to ``"element"``.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): ``(num_atoms, len(atom_mapping))`` one-hot atom types.

        - edges (torch.LongTensor): ``(2, num_edges)`` undirected edges in COO format (both directions stored);
          ``(2, 0)`` when no atom pair lies within the cutoff.

        - edge_weights (torch.FloatTensor): ``(num_edges,)`` weights :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance between node :math:`i` and node :math:`j`.

        - node_pos (torch.FloatTensor): ``(num_atoms, 3)`` x-y-z coordinates of each node.
    :rtype: Tuple
    :raises Exception: Whatever reading the coordinates or building the edge list raised, re-raised after printing ``item['id']``.
    """

    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos.numpy())
        pairs = list(kd_tree.query_pairs(edge_dist_cutoff))
        # view(-1, 2) keeps a (0, 2) shape when no pair is within the cutoff,
        # so the transposed edge tensor is always (2, num_edges).
        edges = torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()
        if edges.numel() > 0:
            edges = to_undirected(edges)
    except Exception:
        # Report which structure failed, then propagate the real error instead of
        # continuing and crashing later on an unbound local.
        print(f"Problem with PDB Id is {item['id']}")
        raise

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]])

    # Vectorized inverse distance; the 1e-5 term avoids division by zero for coincident atoms.
    src, dst = edges
    edge_weights = 1.0 / ((node_pos[src] - node_pos[dst]).norm(dim=1) + 1e-5)

    return node_feats, edges, edge_weights, node_pos


def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with Pytorch-Geometric.

    :param df: Molecule structure in dataframe format, with ``x``, ``y``, ``z`` and ``element`` columns.
    :type df: pandas.DataFrame
    :param bonds: Bond table, one row per bond: columns 0 and 1 are atom indices and the last column is the
        bond order (1.0, 2.0, 3.0 or 1.5). If ``None``, edges are built from ``edge_dist_cutoff`` instead.
    :type bonds: array-like, optional
    :param allowable_atoms: Mapping from ``element`` value to one-hot slot; defaults to ``ligand_atoms_mapping``.
        Elements not in the mapping go to an extra final slot.
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Distance cutoff (in Angstroms) defining an edge when ``bonds`` is ``None``.
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: Only used with ``bonds``. If true, encode bond order as a 4-way one-hot
        (single, double, triple, aromatic) and coalesce duplicate edges; otherwise keep the raw bond order.
    :type onehot_edges: bool, optional
    :return: Tuple containing
        - node_feats (torch.FloatTensor): ``(num_atoms, len(allowable_atoms) + 1)`` one-hot atom types.
        - edge_index (torch.LongTensor): ``(2, num_edges)`` edges in COO format, both directions of every bond/pair.
        - edge_attr (torch.FloatTensor): with ``bonds`` and ``onehot_edges``: ``(num_edges, 4)`` one-hot bond type;
          with ``bonds`` only: ``(num_edges,)`` bond order (Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5);
          without ``bonds``: ``(num_edges, 1)`` inverse distance ``1 / (d + 1e-5)``.
        - node_pos (torch.FloatTensor): ``(num_atoms, 3)`` x-y-z coordinates of each node.
    :raises KeyError: with ``onehot_edges``, if a bond order is not one of 1.0, 2.0, 3.0, 1.5.
    """
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping
    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())

    if bonds is not None:
        N = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        bond_data = torch.FloatTensor(bonds)
        if bond_data.numel() == 0:
            # A molecule without bonds: give the empty table a 2-D shape so the
            # column slicing below yields empty edges instead of an IndexError.
            bond_data = bond_data.reshape(0, 3)
        pairs = bond_data[:, :2]
        # Every bond is stored in both directions: all (i, j) rows, then all (j, i) rows.
        edge_index = torch.cat((pairs, pairs.flip(1)), dim=0).t().long().contiguous()
        bond_orders = bond_data[:, -1]

        if onehot_edges:
            # Same order as edge_index: forward-direction labels, then reverse-direction labels.
            bond_idx = [bond_mapping[order] for order in bond_orders.tolist()] * 2
            edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long), num_classes=4).to(torch.float)
            if edge_index.numel() > 0:
                edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
        else:
            edge_attr = torch.cat((bond_orders, bond_orders), dim=0)
    else:
        kd_tree = ss.KDTree(node_pos.numpy())
        pairs = list(kd_tree.query_pairs(edge_dist_cutoff))
        # view(-1, 2) keeps a (0, 2) shape when no pair is within the cutoff.
        edge_index = torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()
        if edge_index.numel() > 0:
            edge_index = to_undirected(edge_index)
        src, dst = edge_index
        # Vectorized inverse distance; 1e-5 avoids division by zero for coincident atoms.
        edge_attr = (1.0 / ((node_pos[src] - node_pos[dst]).norm(dim=1) + 1e-5)).unsqueeze(1)

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])

    return node_feats, edge_index, edge_attr, node_pos


def one_of_k_encoding_unk_indices(x, allowable_set):
    """Return a one-hot list of length ``len(allowable_set)`` with a 1 at position ``x``.

    ``x`` itself is used as the position, so when it is a member of ``allowable_set``
    it must be a valid integer index (e.g. a key of a dict keyed 0..n-1). Any ``x`` not
    in ``allowable_set`` sets the last position instead.
    """
    vector = [0] * len(allowable_set)
    slot = x if x in allowable_set else -1
    vector[slot] = 1
    return vector

def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """Return a one-hot list with one slot per ``allowable_set`` entry plus a trailing unknown slot.

    ``allowable_set`` maps each allowed value to its slot (e.g. ``ligand_atoms_mapping``);
    ``x`` sets slot ``allowable_set[x]``, and any value missing from the mapping sets the
    extra final slot.
    """
    vector = [0] * (len(allowable_set) + 1)
    slot = allowable_set[x] if x in allowable_set else -1
    vector[slot] = 1
    return vector