TillCyrill commited on
Commit
0da959e
1 Parent(s): 1d6ab1f

scripts from other repo

Browse files
Files changed (9) hide show
  1. 11GS.pdb +0 -0
  2. Dockerfile +9 -0
  3. MDmodel.py +42 -0
  4. app.py +150 -0
  5. best_weights_rep0.pt +3 -0
  6. graph.py +121 -0
  7. inference_for_md.hdf5 +0 -0
  8. transformMD.py +21 -0
  9. transforms.py +46 -0
11GS.pdb ADDED
The diff for this file is too large to render. See raw diff
 
Dockerfile CHANGED
@@ -41,6 +41,15 @@ ENV AMBERHOME="/usr/bin/amber22/"
41
  ENV PATH="$AMBERHOME/bin:$PATH"
42
  ENV PYTHONPATH="$AMBERHOME/lib/python3.10/site-packages"
43
 
 
 
 
 
 
 
 
 
 
44
  RUN adduser -u 5678 --disabled-password --gecos "" appuser && chown -R appuser .
45
  USER appuser
46
  CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
 
41
  ENV PATH="$AMBERHOME/bin:$PATH"
42
  ENV PYTHONPATH="$AMBERHOME/lib/python3.10/site-packages"
43
 
44
+
45
+ RUN useradd -m -u 1000 user
46
+ USER user
47
+ ENV HOME=/home/user \
48
+ PATH=/home/user/.local/bin:$PATH
49
+ WORKDIR $HOME/app
50
+ COPY --chown=user . $HOME/app
51
+
52
+
53
  RUN adduser -u 5678 --disabled-password --gecos "" appuser && chown -R appuser .
54
  USER appuser
55
  CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
MDmodel.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch_geometric.nn import GCNConv
5
+
6
class GNN_MD(torch.nn.Module):
    """Graph convolutional network that emits one scalar per graph node.

    Five ``GCNConv`` layers (each paired with a ``BatchNorm1d``) widen the
    node embedding from ``hidden_dim`` to ``hidden_dim * 8``; two linear
    layers then reduce every node embedding to a single value. No pooling is
    applied, so the output has one entry per node.

    Args:
        num_features: Size of the input node feature vectors.
        hidden_dim: Width of the first convolution; later layers use
            multiples of it (x2, x4, x4, x8).
    """

    def __init__(self, num_features, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim*2)
        self.bn2 = nn.BatchNorm1d(hidden_dim*2)
        self.conv3 = GCNConv(hidden_dim*2, hidden_dim*4)
        self.bn3 = nn.BatchNorm1d(hidden_dim*4)
        self.conv4 = GCNConv(hidden_dim*4, hidden_dim*4)
        self.bn4 = nn.BatchNorm1d(hidden_dim*4)
        self.conv5 = GCNConv(hidden_dim*4, hidden_dim*8)
        self.bn5 = nn.BatchNorm1d(hidden_dim*8)
        self.fc1 = nn.Linear(hidden_dim*8, hidden_dim*4)
        self.fc2 = nn.Linear(hidden_dim*4, 1)

    def forward(self, data):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``edge_attr``; ``edge_attr`` is flattened and passed to every
                convolution as its third argument.

        Returns:
            1-D tensor with one prediction per node.
        """
        edge_weight = data.edge_attr.view(-1)

        # Layers 1-3 apply ReLU *before* batch norm, layers 4-5 apply it
        # *after*. The asymmetry is intentional to keep: swapping the order
        # would change the output of weights trained with this layout.
        x = self.conv1(data.x, data.edge_index, edge_weight)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.conv2(x, data.edge_index, edge_weight)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.conv3(x, data.edge_index, edge_weight)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.conv4(x, data.edge_index, edge_weight)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.conv5(x, data.edge_index, edge_weight)
        x = self.bn5(x)
        x = F.relu(x)
        x = F.relu(self.fc1(x))
        # F.dropout defaults to training=True; tie it to the module mode so
        # that model.eval() really disables dropout and inference is
        # deterministic.
        x = F.dropout(x, p=0.25, training=self.training)
        return self.fc2(x).view(-1)
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import py3Dmol
4
+
5
+ from Bio.PDB import *
6
+
7
+ import numpy as np
8
+ from Bio.PDB import PDBParser
9
+ import pandas as pd
10
+ import torch
11
+ import os
12
+ from MDmodel import GNN_MD
13
+ import h5py
14
+ from transformMD import GNNTransformMD
15
+
16
# JavaScript snippets for 3Dmol hover handling.
# `resid_hover` uses doubled braces and a `{0}` placeholder, so it looks like
# a str.format template -- presumably filled with a residue name; it is not
# referenced elsewhere in this file.
resid_hover = """function(atom,viewer) {{
if(!atom.label) {{
atom.label = viewer.addLabel('{0}:'+atom.atom+atom.serial,
{{position: atom, backgroundColor: 'mintcream', fontColor:'black'}});
}}
}}"""
# Adds a label showing `atom.interaction` when the atom has none yet.
hover_func = """
function(atom,viewer) {
if(!atom.label) {
atom.label = viewer.addLabel(atom.interaction,
{position: atom, backgroundColor: 'black', fontColor:'white'});
}
}"""
# Removes the label created by the hover callbacks.
unhover_func = """
function(atom,viewer) {
if(atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}"""
# Element index -> element symbol; the last entry is the "unknown" bin.
atom_mapping = dict(enumerate(
    ('H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK')
))

# Restore the trained weights on CPU and switch the network to eval mode.
# 11 input features matches the 11 entries of `atom_mapping`.
checkpoint = torch.load("best_weights_rep0.pt", map_location=torch.device("cpu"))
state_dict = checkpoint["model_state_dict"]
model = GNN_MD(11, 64)
model.load_state_dict(state_dict)
model = model.to('cpu')
model.eval()
47
+
48
+
49
+
50
def get_pdb(pdb_code="", filepath=""):
    """Resolve the path of the PDB file to work on.

    An uploaded file wins over a PDB code. When only a code is given, the
    entry is fetched from RCSB with ``wget`` into the working directory
    (``-nc``: an already downloaded file is not fetched again).

    Args:
        pdb_code: PDB identifier typed by the user, or empty / ``None``.
        filepath: Upload object exposing a ``name`` attribute; anything
            without that attribute (``""``, ``None``) counts as "no upload".

    Returns:
        Path of the PDB file, or ``None`` when there is no upload and the
        code is empty or not a plain ASCII alphanumeric string.
    """
    try:
        return filepath.name
    except AttributeError:
        pass

    if pdb_code is None or pdb_code == "":
        return None

    # The code comes straight from a text box: refuse anything that could
    # alter the URL or the local file name (slashes, dots, shell syntax).
    if not (isinstance(pdb_code, str) and pdb_code.isascii() and pdb_code.isalnum()):
        return None

    import subprocess

    # argv list, no shell: the value can never be interpreted as a command.
    # check=False keeps the original best-effort behaviour on download errors.
    subprocess.run(
        ["wget", "-qnc", f"https://files.rcsb.org/view/{pdb_code}.pdb"],
        check=False,
    )
    return f"{pdb_code}.pdb"
59
+
60
+
61
def get_offset(pdb):
    """Return the residue sequence number of the first ``ATOM`` record.

    Args:
        pdb: Full text of a PDB file.

    Returns:
        The integer residue number of the first line starting with ``ATOM``,
        or ``None`` when the text contains no such line.

    Raises:
        ValueError: If the residue-number field of that line is not an
            integer (e.g. the line is truncated).
    """
    for line in pdb.splitlines():
        if line.startswith("ATOM"):
            # Fixed-width PDB layout: characters 22..25 (0-based) hold the
            # residue number; character 26 is the insertion code and must not
            # be parsed, otherwise residues such as "52A" raise ValueError.
            return int(line[22:26])
    return None
66
+
67
+
68
def predict(pdb_code, pdb_file):
    """Predict per-atom adaptability for entry ``11GS`` and render it.

    The atoms always come from the ``11GS`` group of
    ``inference_for_md.hdf5``; the structure drawn in the viewer is read from
    ``pdb_file``.

    Args:
        pdb_code: Text typed by the user. Currently unused.
        pdb_file: Uploaded file (object with a ``name`` attribute, or a path
            string). When ``None`` the bundled ``11GS.pdb`` is shown instead.

    Returns:
        Tuple ``(html, table)``: an ``<iframe>`` string embedding the 3Dmol
        view with the highest-scoring atoms marked by orange spheres, and a
        DataFrame with columns ``index, element, x, y, z, Adaptability``.
    """
    # switch to misato env if not running from container
    mdh5_file = "inference_for_md.hdf5"
    with h5py.File(mdh5_file, "r") as md_H5File:
        entry = md_H5File["11GS"]
        cutoff = entry["molecules_begin_atom_index"][:][-1]  # cutoff defines protein atoms
        # Read each dataset once instead of once per column.
        coords = entry["atoms_coordinates_ref"][:][:cutoff]
        elements = entry["atoms_element"][:][:cutoff]

    atoms_protein = pd.DataFrame({
        "x": coords[:, 0],
        "y": coords[:, 1],
        "z": coords[:, 2],
        "element": elements,
    })

    item = {"scores": 0, "id": "11GS", "atoms_protein": atoms_protein}

    transform = GNNTransformMD()
    data_item = transform(item)
    with torch.no_grad():
        adaptability = model(data_item).cpu().numpy()

    # `atom_mapping` is a dict: index it (the original *called* it, which
    # raises TypeError). Stored element codes are 1-based, hence the -1;
    # codes outside the table fall into the 'UNK' bin.
    data = [
        [
            i,
            atom_mapping.get(int(elements[i]) - 1, 'UNK'),
            coords[i, 0],
            coords[i, 1],
            coords[i, 2],
            adaptability[i],
        ]
        for i in range(adaptability.shape[0])
    ]

    topN = 100
    # Slicing copes with structures that have fewer than topN atoms.
    topN_ind = np.argsort(adaptability)[::-1][:topN]

    # Fall back to the bundled structure when nothing was uploaded.
    # NOTE(review): assumes 11GS.pdb sits next to this script -- confirm.
    pdb_path = "11GS.pdb" if pdb_file is None else getattr(pdb_file, "name", pdb_file)
    with open(pdb_path, "r") as handle:
        pdb = handle.read()

    view = py3Dmol.view(width=600, height=400)
    view.setBackgroundColor('black')
    view.addModel(pdb, "pdb")
    view.setStyle({'stick': {'colorscheme': {'prop': 'resi', 'C': 'turquoise'}}})

    for idx in topN_ind:
        # Plain Python floats: numpy scalars may not serialise into the
        # JavaScript call the viewer generates.
        view.addSphere({
            'center': {
                'x': float(coords[idx, 0]),
                'y': float(coords[idx, 1]),
                'z': float(coords[idx, 2]),
            },
            'radius': float(adaptability[idx]) / 1.5,
            'color': 'orange',
            'alpha': 0.75,
        })

    view.zoomTo()

    # The page is embedded in a single-quoted srcdoc attribute below, so it
    # must not contain single quotes itself.
    output = view._make_html().replace("'", '"')

    x = f"""<!DOCTYPE html><html> {output} </html>"""  # do not use ' in this input
    iframe = f"""<iframe style="width: 100%; height:420px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
    table = pd.DataFrame(data, columns=['index','element','x','y','z','Adaptability'])
    return iframe, table
125
+
126
+
127
callback = gr.CSVLogger()

# Page layout: a text box and a file upload as inputs, one button, and the
# 3D view plus the per-atom table as outputs of `predict`.
with gr.Blocks() as demo:
    gr.Markdown("# Protein Adaptability Prediction")

    inp = gr.Textbox(placeholder="PDB Code or upload file below", label="Input structure")
    pdb_file = gr.File(label="PDB File Upload")
    # The caption of a button is its `value` (first positional argument);
    # `label` is not the button text.
    single_btn = gr.Button("Run")
    with gr.Row():
        html = gr.HTML()
    with gr.Row():
        dataframe = gr.Dataframe()

    single_btn.click(fn=predict, inputs=[inp, pdb_file], outputs=[html, dataframe])


demo.launch(debug=True)
best_weights_rep0.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6b9499f2dc7b16eb3c669d2c6c26e11a53e21047264d3eb0cdda6bbc1d17f91
3
+ size 4517600
graph.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.spatial as ss
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch_geometric.utils import to_undirected
6
+ from torch_sparse import coalesce
7
+
8
# Element index -> element symbol for protein atoms; last entry is the
# "unknown" bin.
atom_mapping = dict(enumerate(
    ('H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK')
))

# Residue index -> three-letter residue name; last entry is the "unknown" bin.
residue_mapping = dict(enumerate((
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE',
    'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR',
    'VAL', 'UNK',
)))

# Ligand element code -> one-hot position. The codes look like atomic
# numbers (8, 16, 6, 7, 1, ...) -- TODO confirm against the data source.
ligand_atoms_mapping = {
    code: position
    for position, code in enumerate((
        8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33,
        26, 14, 34, 44, 12, 23, 77, 27, 52, 30, 4, 45,
    ))
}
12
+
13
+
14
def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.

    :param item: Dataset item the dataframe belongs to; only ``item['id']`` is read, for the error message.
    :type item: dict
    :param df: Protein structure in dataframe format. Must provide ``x``, ``y``, ``z`` and ``feat_col`` columns.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance between two atoms for them to be connected by an edge.
    :type edge_dist_cutoff: float
    :param feat_col: Column holding the node type code. The code minus one is one-hot encoded against ``atom_mapping``; codes outside the table go to the last ("unknown") bin.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): One-hot features for each node.

        - edges (torch.LongTensor): Undirected edges in COO format, shape ``(2, E)``.

        - edge_weights (torch.FloatTensor): Edge weights :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance between node :math:`i` and node :math:`j`.

        - node_pos (torch.FloatTensor): x-y-z coordinates of each node
    :rtype: Tuple
    :raises ValueError: If the coordinates cannot be read or indexed; the message names the offending id.
    """

    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos.numpy())
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
    except (KeyError, TypeError, ValueError) as err:
        # Fail here with the entry id instead of printing and running into a
        # NameError on the first use of the variables that were never set.
        raise ValueError(f"Problem with PDB Id is {item['id']}") from err

    if edge_tuples:
        edges = torch.LongTensor(edge_tuples).t().contiguous()
        edges = to_undirected(edges)
    else:
        # No pair within the cutoff: keep a well-formed, empty COO tensor.
        edges = torch.empty((2, 0), dtype=torch.long)

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices(e-1, allowable_feats) for e in df[feat_col]])

    # Inverse-distance weights for all edges at once; the epsilon avoids a
    # division by zero for coincident atoms.
    src, dst = edges
    delta = node_pos[src] - node_pos[dst]
    edge_weights = (1.0 / (torch.sqrt((delta * delta).sum(dim=1)) + 1e-5)).view(-1)

    return node_feats, edges, edge_weights, node_pos
57
+
58
+
59
def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with Pytorch-Geometric
    :param df: Molecule structure in dataframe format
    :type df: pandas.DataFrame
    :param bonds: Bond table; the first two columns are read as atom indices and the last column as bond type. When ``None``, edges are built from distances instead.
    :type bonds: array-like, optional
    :param allowable_atoms: Mapping from element code to one-hot position, defaults to ``ligand_atoms_mapping``
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Maximum distance for a distance-based edge (only used without ``bonds``), defaults to 4.5
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: With ``bonds``: one-hot encode the bond type (4 classes) instead of returning its numeric value, defaults to True
    :type onehot_edges: bool, optional
    :return: Tuple containing \n
        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms`` plus one trailing "unknown" bin.
        - edge_index (torch.LongTensor): Edges in COO format (both directions).
        - edge_feats (torch.FloatTensor): With ``bonds``: bond type, one-hot or numeric (Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5). Without: inverse distance, shape ``(E, 1)``.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.
    """
    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation was lost (coalesce inside the one-hot branch, unsqueeze
    # inside the distance branch) -- confirm against the upstream file.
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping
    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())

    if bonds is not None:
        N = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        bond_data = torch.FloatTensor(bonds)
        # Every bond is listed twice, once per direction.
        edge_tuples = torch.cat((bond_data[:, :2], torch.flip(bond_data[:, :2], dims=(1,))), dim=0)
        edge_index = edge_tuples.t().long().contiguous()
        bond_types = bond_data[:, -1]

        if onehot_edges:
            # Same duplication as for the edges, computed once.
            bond_idx = [bond_mapping[b] for b in bond_types.tolist()] * 2
            edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).to(torch.float)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
        else:
            edge_attr = torch.cat((bond_types.view(-1), bond_types.view(-1)), dim=0)
    else:
        kd_tree = ss.KDTree(node_pos.numpy())
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edge_index = torch.LongTensor(edge_tuples).t().contiguous()
            edge_index = to_undirected(edge_index)
        else:
            # Single atom or nothing within the cutoff: empty COO tensor.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Inverse-distance weights for all edges at once.
        src, dst = edge_index
        delta = node_pos[src] - node_pos[dst]
        edge_attr = (1.0 / (torch.sqrt((delta * delta).sum(dim=1)) + 1e-5)).view(-1)
        edge_attr = edge_attr.unsqueeze(1)

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])

    return node_feats, edge_index, edge_attr, node_pos
103
+
104
+
105
def one_of_k_encoding_unk_indices(x, allowable_set):
    """One-hot encode the index ``x`` over ``len(allowable_set)`` slots.

    When ``x`` is contained in ``allowable_set`` the slot at position ``x``
    is set; otherwise the last slot is set.
    """
    hot = x if x in allowable_set else -1
    one_hot_encoding = [0] * len(allowable_set)
    one_hot_encoding[hot] = 1
    return one_hot_encoding
113
+
114
def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """One-hot encode ``x`` through the mapping ``allowable_set``.

    The result has ``len(allowable_set) + 1`` slots. A known ``x`` sets the
    slot ``allowable_set[x]``; any other value sets the extra last slot.
    """
    position = allowable_set[x] if x in allowable_set else -1
    one_hot_encoding = [0] * (len(allowable_set) + 1)
    one_hot_encoding[position] = 1
    return one_hot_encoding
inference_for_md.hdf5 ADDED
Binary file (61.9 kB). View file
 
transformMD.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transforms import prot_graph_transform
2
+
3
class GNNTransformMD:
    """
    Callable that turns a dataset item (dict) into a PyTorch Geometric graph.

    The item is handed to ``prot_graph_transform`` with the atoms under
    ``'atoms_protein'`` and the label under ``'scores'``.
    """

    def __init__(self, edge_dist_cutoff=4.5):
        """
        Args:
            edge_dist_cutoff (float, optional): maximum distance between two
                atoms for an edge to be created. Defaults to 4.5.
        """
        self.edge_dist_cutoff = edge_dist_cutoff

    def __call__(self, item):
        """Transform ``item`` and return the graph stored under ``'atoms_protein'``."""
        transformed = prot_graph_transform(
            item,
            atom_keys=['atoms_protein'],
            label_key='scores',
            edge_dist_cutoff=self.edge_dist_cutoff,
        )
        return transformed['atoms_protein']
19
+
20
+
21
+
transforms.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.data import Data
3
+ from graph import prot_df_to_graph, mol_df_to_graph_for_qm
4
+
5
def prot_graph_transform(item, atom_keys, label_key, edge_dist_cutoff):
    """Transform for converting dataframes to Pytorch Geometric graphs.

    Operates on a dataset item in place: every dataframe listed in
    ``atom_keys`` is replaced by a ``Data`` graph built with
    ``prot_df_to_graph``.

    :param item: Dataset item to transform; must contain every key in ``atom_keys``, ``label_key`` and ``"id"``
    :type item: dict
    :param atom_keys: list of keys to transform, where each key contains a dataframe of atoms
    :type atom_keys: list
    :param label_key: name of key containing the label(s)
    :type label_key: str
    :param edge_dist_cutoff: maximum distance between two atoms for an edge
    :type edge_dist_cutoff: float
    :return: Transformed Dataset item (the same dict object)
    :rtype: dict
    """
    # torch.FloatTensor(n) with an int n allocates an (uninitialised) tensor
    # of *size* n, so a scalar label such as 0 became an empty tensor.
    # as_tensor converts the value itself, for scalars and sequences alike.
    label = torch.as_tensor(item[label_key], dtype=torch.float32)

    for key in atom_keys:
        node_feats, edge_index, edge_feats, pos = prot_df_to_graph(item, item[key], edge_dist_cutoff)
        item[key] = Data(
            x=node_feats,
            edge_index=edge_index,
            edge_attr=edge_feats,
            y=label,
            pos=pos,
            ids=item["id"],
        )

    return item
24
+
25
def mol_graph_transform_for_qm(item, atom_key, label_key, allowable_atoms, use_bonds, onehot_edges, edge_dist_cutoff):
    """Replace the molecule dataframe of a dataset item by a Pytorch Geometric graph.

    The graph is built with ``mol_df_to_graph_for_qm``; the item is modified
    in place and returned.

    :param item: Dataset item to transform
    :type item: dict
    :param atom_key: name of key containing molecule structure as a dataframe
    :type atom_key: str
    :param label_key: name of key containing labels
    :type label_key: str
    :param allowable_atoms: forwarded to ``mol_df_to_graph_for_qm``
    :param use_bonds: whether to use molecular bond information for edges instead of distance. Bonds are read from the ``'bonds'`` key of the item
    :type use_bonds: bool
    :param onehot_edges: forwarded to ``mol_df_to_graph_for_qm``
    :type onehot_edges: bool
    :param edge_dist_cutoff: forwarded to ``mol_df_to_graph_for_qm``
    :type edge_dist_cutoff: float
    :return: Transformed Dataset item
    :rtype: dict
    """
    if use_bonds:
        bonds = item['bonds']
    else:
        bonds = None

    graph_parts = mol_df_to_graph_for_qm(
        item[atom_key],
        bonds=bonds,
        onehot_edges=onehot_edges,
        allowable_atoms=allowable_atoms,
        edge_dist_cutoff=edge_dist_cutoff,
    )
    node_feats, edge_index, edge_feats, pos = graph_parts
    # Positional order of Data is (x, edge_index, edge_attr).
    item[atom_key] = Data(node_feats, edge_index, edge_feats, y=item[label_key], pos=pos)

    return item