TillCyrill
scripts from other repo
0da959e
raw
history blame
621 Bytes
from transforms import prot_graph_transform
class GNNTransformMD:
    """Callable transform that turns a dataset item into a protein graph.

    Per the original author's note, the input is the dict returned by the
    ProtDataset class and the output is a PyTorch Geometric graph. The
    actual conversion is delegated to ``prot_graph_transform``; this class
    only fixes the keys used and stores the distance cutoff.
    """

    def __init__(self, edge_dist_cutoff: float = 4.5):
        """Store the cutoff that is later forwarded to the graph builder.

        Args:
            edge_dist_cutoff (float, optional): Distance cutoff passed to
                ``prot_graph_transform`` as ``edge_dist_cutoff``. Defaults
                to 4.5.
                NOTE(review): presumably the maximum distance between two
                atoms for an edge to be created -- confirm against
                ``prot_graph_transform``.
        """
        self.edge_dist_cutoff = edge_dist_cutoff

    def __call__(self, item):
        """Convert ``item`` and return its ``'atoms_protein'`` entry.

        Args:
            item: Mapping handed to ``prot_graph_transform``; it is
                processed with ``atom_keys=['atoms_protein']`` and
                ``label_key='scores'``.

        Returns:
            The value stored under ``'atoms_protein'`` in the result of
            ``prot_graph_transform`` (expected to be the graph object).
        """
        transformed = prot_graph_transform(
            item,
            atom_keys=['atoms_protein'],
            label_key='scores',
            edge_dist_cutoff=self.edge_dist_cutoff,
        )
        # Only the protein graph is handed back; any other entries of the
        # transformed item are discarded.
        return transformed['atoms_protein']