File size: 621 Bytes
0da959e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transforms import prot_graph_transform
        
class GNNTransformMD:
    """Transform a dict returned by the ProtDataset class into a PyTorch Geometric graph.

    The conversion itself is delegated to ``prot_graph_transform``. This class
    only fixes which atom key, label key and distance cutoff are passed to it,
    and which entry of the result is returned, so an instance can be used as a
    callable per-item transform.
    """

    def __init__(self, edge_dist_cutoff=4.5, *, atom_key='atoms_protein',
                 label_key='scores'):
        """
        Args:
            edge_dist_cutoff (float, optional): Distance cutoff forwarded to
                ``prot_graph_transform``. Presumably atoms closer than this
                are connected by an edge; units likely Angstrom. TODO confirm.
                Defaults to 4.5.
            atom_key (str, optional): Key of the atom entry in the item dict
                that is converted and returned. Defaults to 'atoms_protein'.
            label_key (str, optional): Key of the label entry in the item
                dict, forwarded to ``prot_graph_transform`` as ``label_key``.
                Defaults to 'scores'.
        """
        self.edge_dist_cutoff = edge_dist_cutoff
        self.atom_key = atom_key
        self.label_key = label_key

    def __call__(self, item):
        """Convert ``item`` and return the entry stored under ``atom_key``.

        Args:
            item (dict): A single dataset item. It must contain ``atom_key``,
                and presumably ``label_key`` as well; confirm against
                ``prot_graph_transform``.

        Returns:
            The value that ``prot_graph_transform`` places under ``atom_key``.
            Per the class's intent, this is a PyTorch Geometric graph.
        """
        # Only one atom key is converted, so wrap it in a list to match the
        # ``atom_keys`` argument of prot_graph_transform.
        item = prot_graph_transform(
            item,
            atom_keys=[self.atom_key],
            label_key=self.label_key,
            edge_dist_cutoff=self.edge_dist_cutoff,
        )
        return item[self.atom_key]

    def __repr__(self):
        return (f"{type(self).__name__}(edge_dist_cutoff={self.edge_dist_cutoff!r}, "
                f"atom_key={self.atom_key!r}, label_key={self.label_key!r})")